Teerth Patel
commited on
Commit
•
199a42f
1
Parent(s):
772112c
initial commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .DS_Store +0 -0
- .gitignore +40 -0
- Dockerfile +18 -0
- app.py +77 -0
- benchmarks/CLRS/env/__init__.py +14 -0
- benchmarks/CLRS/env/baseline_model_description.txt +507 -0
- benchmarks/CLRS/env/baselines.py +794 -0
- benchmarks/CLRS/env/baselines_test.py +294 -0
- benchmarks/CLRS/env/data_description.txt +35 -0
- benchmarks/CLRS/env/dataset.py +326 -0
- benchmarks/CLRS/env/dataset_test.py +116 -0
- benchmarks/CLRS/env/decoders.py +381 -0
- benchmarks/CLRS/env/decoders_test.py +47 -0
- benchmarks/CLRS/env/encoders.py +139 -0
- benchmarks/CLRS/env/evaluation.py +202 -0
- benchmarks/CLRS/env/evaluation_test.py +55 -0
- benchmarks/CLRS/env/losses.py +209 -0
- benchmarks/CLRS/env/losses_test.py +166 -0
- benchmarks/CLRS/env/model.py +46 -0
- benchmarks/CLRS/env/nets.py +719 -0
- benchmarks/CLRS/env/probing.py +351 -0
- benchmarks/CLRS/env/probing_test.py +192 -0
- benchmarks/CLRS/env/processors.py +856 -0
- benchmarks/CLRS/env/processors_test.py +64 -0
- benchmarks/CLRS/env/samplers.py +882 -0
- benchmarks/CLRS/env/samplers_test.py +250 -0
- benchmarks/CLRS/env/specs.py +525 -0
- benchmarks/CLRS/env/train.py +560 -0
- benchmarks/CLRS/scripts/eval.py +454 -0
- benchmarks/CLRS/scripts/requirements.txt +13 -0
- benchmarks/CLRS/scripts/research_problem.txt +3 -0
- benchmarks/CLRS/scripts/source_code.txt +1 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/env/data_description.txt +33 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/env/evaluation_details.txt +12 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/env/public_timeseries_testing_util.py +94 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/env/train.py +141 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/scripts/eval.py +21 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/scripts/prepare.py +79 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/scripts/read_only_files.txt +5 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/scripts/research_problem.txt +3 -0
- benchmarks/amp-parkinsons-disease-progression-prediction/scripts/source_code.txt +2 -0
- benchmarks/babylm/env/babyLM_for_hf.py +104 -0
- benchmarks/babylm/env/train.py +641 -0
- benchmarks/babylm/scripts/eval.py +212 -0
- benchmarks/babylm/scripts/prepare.py +11 -0
- benchmarks/babylm/scripts/read_only_files.txt +2 -0
- benchmarks/babylm/scripts/research_problem.txt +7 -0
- benchmarks/bibtex-generation/env/arxiv_API_reference.txt +599 -0
- benchmarks/bibtex-generation/env/bibtex_generation.py +0 -0
- benchmarks/bibtex-generation/env/claude_example.py +11 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitignore
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*_api_key.txt
|
2 |
+
|
3 |
+
__pycache__
|
4 |
+
build
|
5 |
+
dist
|
6 |
+
/.venv/
|
7 |
+
.env
|
8 |
+
.vscode
|
9 |
+
.mypy_cache
|
10 |
+
.pytest_cache
|
11 |
+
*.pyc
|
12 |
+
flagged
|
13 |
+
wandb
|
14 |
+
logs
|
15 |
+
temp
|
16 |
+
tests/logs
|
17 |
+
tests/wandb
|
18 |
+
|
19 |
+
# Outputs generated by the cli demo
|
20 |
+
cached_generated_dataset/
|
21 |
+
generated_dataset/
|
22 |
+
huggingface_data/huggingface_datasets/dataset_index.json
|
23 |
+
huggingface_data/huggingface_datasets/huggingface_datasets_datafinder_index
|
24 |
+
huggingface_data/huggingface_datasets/reranking_dataset_index.json
|
25 |
+
huggingface_data/huggingface_models/
|
26 |
+
retrieved_dataset_dict/
|
27 |
+
result/
|
28 |
+
checkpoint/
|
29 |
+
status.yaml
|
30 |
+
# Outputs generated by the colab demo
|
31 |
+
trained_model/
|
32 |
+
trained_tokenizer/
|
33 |
+
|
34 |
+
|
35 |
+
# data
|
36 |
+
babylm_data
|
37 |
+
checkpoint-*/
|
38 |
+
|
39 |
+
# hf models info
|
40 |
+
huggingface_models/
|
Dockerfile
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM anibali/pytorch:2.0.0-cuda11.8-ubuntu22.04
|
2 |
+
|
3 |
+
# Set up time zone.
|
4 |
+
ENV TZ=UTC
|
5 |
+
RUN sudo ln -snf /usr/share/zoneinfo/$TZ /etc/localtime
|
6 |
+
|
7 |
+
USER root
|
8 |
+
RUN apt update && apt install -y gcc-10 g++-10 && ln /usr/bin/gcc-10 /usr/bin/gcc && ln /usr/bin/g++-10 /usr/bin/g++ && apt install -y zlib1g-dev && rm -r /var/lib/apt/lists/*
|
9 |
+
|
10 |
+
# copy files
|
11 |
+
WORKDIR /app
|
12 |
+
COPY . .
|
13 |
+
|
14 |
+
# Install libraries
|
15 |
+
RUN python3 -m pip install -r requirements.txt
|
16 |
+
|
17 |
+
# start bash shell
|
18 |
+
CMD bash
|
app.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from pathlib import Path
|
3 |
+
from reactagent.environment import Environment
|
4 |
+
from reactagent.agents.agent_research import ResearchAgent
|
5 |
+
from reactagent.runner import create_parser
|
6 |
+
from reactagent import llm
|
7 |
+
from reactagent.users.user import User
|
8 |
+
|
9 |
+
class SessionInfo:
|
10 |
+
def __init__(self):
|
11 |
+
self.coro_cache = {}
|
12 |
+
self.parser = create_parser()
|
13 |
+
|
14 |
+
def make_session(self, prompt, session_hash):
|
15 |
+
id = session_hash
|
16 |
+
|
17 |
+
llm_name='claude-3-5-sonnet-20240620'
|
18 |
+
fastllm_name='claude-3-haiku-20240307'
|
19 |
+
rawargs = [
|
20 |
+
'--research-problem', prompt,
|
21 |
+
'--log-dir', str(Path('logs', id)),
|
22 |
+
'--work-dir', str(Path('workspaces', id)),
|
23 |
+
'--llm-name', llm_name,
|
24 |
+
'--edit-script-llm-name', llm_name,
|
25 |
+
'--fast-llm-name', fastllm_name,
|
26 |
+
]
|
27 |
+
|
28 |
+
args = self.parser.parse_args(rawargs)
|
29 |
+
llm.FAST_MODEL = args.fast_llm_name
|
30 |
+
env = Environment(args)
|
31 |
+
agent = ResearchAgent(args, env)
|
32 |
+
coro = agent.run(env)
|
33 |
+
|
34 |
+
self.coro_cache[id] = coro
|
35 |
+
return id
|
36 |
+
|
37 |
+
def get_response(self, human_input, session_hash):
|
38 |
+
coro_input = human_input
|
39 |
+
if session_hash not in self.coro_cache:
|
40 |
+
self.make_session(human_input, session_hash)
|
41 |
+
coro_input = None
|
42 |
+
|
43 |
+
try:
|
44 |
+
output = self.coro_cache[session_hash].send(coro_input)
|
45 |
+
except StopIteration:
|
46 |
+
output = None
|
47 |
+
del self.coro_cache[session_hash]
|
48 |
+
|
49 |
+
return output
|
50 |
+
|
51 |
+
session_info = SessionInfo()
|
52 |
+
|
53 |
+
def info_to_message(info):
|
54 |
+
msg = ""
|
55 |
+
for k, v in info.items():
|
56 |
+
if isinstance(v, dict):
|
57 |
+
tempv = v
|
58 |
+
v = ""
|
59 |
+
for k2, v2 in tempv.items():
|
60 |
+
v += f"{k2}:\n {v2}\n"
|
61 |
+
v = User.indent_text(v, 2)
|
62 |
+
msg += '-' * 64
|
63 |
+
msg += '\n'
|
64 |
+
msg += f"{k}:\n{v}\n"
|
65 |
+
|
66 |
+
msg += "Please provide feedback based on the history, response entries, and observation, and questions: "
|
67 |
+
return msg
|
68 |
+
|
69 |
+
def predict(message, history, request: gr.Request):
|
70 |
+
response = session_info.get_response(message, request.session_hash)
|
71 |
+
if response is None:
|
72 |
+
response = "Agent is finished. Enter a new instruction."
|
73 |
+
return response
|
74 |
+
|
75 |
+
if __name__ == "__main__":
|
76 |
+
demo = gr.ChatInterface(predict)
|
77 |
+
demo.launch()
|
benchmarks/CLRS/env/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
benchmarks/CLRS/env/baseline_model_description.txt
ADDED
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The BaselineModel class in baselines.py file is a full working Graph Neural Network (GNN) example using JAX and the DeepMind JAX Ecosystem of libraries. It allows training of multiple algorithms on a single processor, as described in the paper "A Generalist Neural Algorithmic Learner" (arXiv:2209.11142v2 [cs.LG] 3 Dec 2022). Below is an excerpt from the paper that describes the model:
|
2 |
+
|
3 |
+
Each algorithm in the CLRS benchmark [5] is specified by a number of inputs, hints and outputs. In
|
4 |
+
a given sample, the inputs and outputs are fixed, while hints are time-series of intermediate states of
|
5 |
+
the algorithm. Each sample for a particular task has a size, n, corresponding to the number of nodes
|
6 |
+
in the GNN that will execute the algorithm.
|
7 |
+
A sample of every algorithm is represented as a graph, with each input, output and hint located in
|
8 |
+
either the nodes, the edges, or the graph itself, and therefore has shape (excluding batch dimension,
|
9 |
+
and, for hints, time dimension) n × f , n × n × f , or f , respectively, f being the dimensionality of
|
10 |
+
the feature, which depends on its type. The CLRS benchmark defines five types of features: scalar,
|
11 |
+
categorical, mask, mask_one and pointer, with their own encoding and decoding strategies and
|
12 |
+
loss functions—e.g. a scalar type will be encoded and decoded directly by a single linear layer, and
|
13 |
+
optimised using mean squared error.
|
14 |
+
|
15 |
+
Base Model
|
16 |
+
|
17 |
+
Encoder. We adopt the same encode-process-decode paradigm [33] presented with the CLRS
|
18 |
+
benchmark [5]. At each time step, t, of a particular task τ (e.g. insertion sort), the task-based encoder
|
19 |
+
fτ , consisting of a linear encoder for each input and hint, embeds inputs and the current hints as
|
20 |
+
high-dimensional vectors. These embeddings of inputs and hints located in the nodes all have the
|
21 |
+
same dimension and are added together; the same happens with hints and inputs located in edges,
|
22 |
+
and in the graph. In our experiments we use the same dimension, h = 128, for node, edge and graph
|
23 |
+
3
|
24 |
+
|
25 |
+
A Generalist Neural Algorithmic Learner
|
26 |
+
|
27 |
+
embeddings. Thus, at the
|
28 |
+
step for a time-step t of the algorithm, we have a
|
29 |
+
n end of the encoding
|
30 |
+
o
|
31 |
+
(t) (t)
|
32 |
+
(t)
|
33 |
+
single set of embeddings xi , eij , g
|
34 |
+
, shapes n × h, n × n × h, and h, in the nodes, edges and
|
35 |
+
graph, respectively. Note that this is independent of the number and type of the inputs and hints of
|
36 |
+
the particular algorithm, allowing us to share this latent space across all thirty algorithms in CLRS.
|
37 |
+
Further, note that at each step, the input encoding is fed directly to these embeddings—this recall
|
38 |
+
mechanism significantly improves the model’s robustness over long trajectories [34].
|
39 |
+
Processor. The embeddings are fed into a processor P , a GNN that performs one step of computation. The processor transforms the input node, edge and graph embeddings into processed
|
40 |
+
(t)
|
41 |
+
node embeddings, hi . Additionally, the processor uses the processed node embeddings from the
|
42 |
+
(t−1)
|
43 |
+
previous step, hi
|
44 |
+
, as inputs. Importantly, the same processor model can operate on graphs of any
|
45 |
+
size. We leverage the message-passing neural network [35, MPNN], using the max aggregation and
|
46 |
+
passing messages over a fully-connected graph, as our base model. The MPNN computes processed
|
47 |
+
embeddings as follows:
|
48 |
+
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
(t)
|
53 |
+
(t−1)
|
54 |
+
(t)
|
55 |
+
(t) (t) (t)
|
56 |
+
(t)
|
57 |
+
(t)
|
58 |
+
(t)
|
59 |
+
z(t) = xi khi
|
60 |
+
mi = max fm zi , zj , eij , g(t)
|
61 |
+
hi = fr zi , mi
|
62 |
+
(1)
|
63 |
+
1≤j≤n
|
64 |
+
|
65 |
+
starting from h(0) = 0. Here k denotes concatenation, fm : R2h × R2h × Rh × Rh → Rh is the
|
66 |
+
message function (for which we use a three-layer MLP with ReLU activations), and fr : R2h × Rh →
|
67 |
+
Rh is the readout function (for which we use a linear layer with ReLU activation). The use of the max
|
68 |
+
aggregator is well-motivated by prior work [5, 9], and we use the fully connected graph—letting the
|
69 |
+
neighbours j range over all nodes (1 ≤ j ≤ n)—in order to allow the model to overcome situations
|
70 |
+
(t)
|
71 |
+
where the input graph structure may be suboptimal. Layer normalisation [36] is applied to hi before
|
72 |
+
using them further. Further details on the MPNN processor may be found in Veličković et al. [5].
|
73 |
+
Decoder. The processed embeddings are finally decoded with a task-based decoder gτ , to predict
|
74 |
+
the hints for the next step, and the outputs at the final step. Akin to the encoder, the task-based decoder
|
75 |
+
relies mainly on a linear decoder for each hint and output, along with a mechanism to compute
|
76 |
+
pairwise node similarities when appropriate. Specifically, the pointer type decoder computes
|
77 |
+
a score, sij , for each pair of nodes, and then chooses the pointer of node i by taking either the
|
78 |
+
argmaxj sij or softmaxj sij (depending on whether a hard or soft prediction is used).
|
79 |
+
Loss. The decoded hints and outputs are used to compute the loss during training, according to their
|
80 |
+
type [5]. For each sample in a batch, the hint prediction losses are averaged across hints and time,
|
81 |
+
and the output loss is averaged across outputs (most algorithms have a single output, though some
|
82 |
+
have two outputs). The hint loss and output loss are added together. Besides, the hint predictions at
|
83 |
+
each time step are fed back as inputs for the next step, except possibly at train time if teacher forcing
|
84 |
+
is used (see Section 3.2.1).
|
85 |
+
We train the model on samples with sizes n ≤ 16, and periodically evaluate them on in-distribution
|
86 |
+
samples of size n = 16. Also, periodically, we evaluate the model with the best in-distribution
|
87 |
+
evaluation score so far on OOD samples of size n = 64. In what follows, we will be reporting only
|
88 |
+
these OOD evaluation scores. Full details of the model, training and evaluation hyperparameters can
|
89 |
+
be found in Appendix A.
|
90 |
+
3.2
|
91 |
+
|
92 |
+
Model improvements
|
93 |
+
|
94 |
+
As previously discussed, single-task improvements, especially in terms of learning stability, will
|
95 |
+
empirically transfer well to multi-task algorithmic learning. We now describe, in a gradual manner,
|
96 |
+
all the changes made to the model, which have lead to an absolute improvement of over 20% on
|
97 |
+
average across all 30 tasks in CLRS.
|
98 |
+
3.2.1
|
99 |
+
|
100 |
+
Dataset and training
|
101 |
+
|
102 |
+
Removing teacher forcing. At evaluation time, the model has no access to the step-by-step hints
|
103 |
+
in the dataset, and has to rely on its own hint predictions. However, during training, it is sometimes
|
104 |
+
advisable to stabilise the trajectories with teacher forcing [37]—providing the ground-truth hint
|
105 |
+
values instead of the network’s own predictions. In the prior model [5], ground-truth hints were
|
106 |
+
4
|
107 |
+
|
108 |
+
A Generalist Neural Algorithmic Learner
|
109 |
+
|
110 |
+
provided during training with probability 0.5, as, without teacher forcing, losses tended to grow
|
111 |
+
unbounded along a trajectory when scalar hints were present, destabilising the training. In this
|
112 |
+
work we incorporate several significant stabilising changes (described in future paragraphs), which
|
113 |
+
allows us to remove teacher forcing altogether, aligning training with evaluation, and avoiding the
|
114 |
+
network becoming overconfident in always expecting correct hint predictions. With teacher forcing,
|
115 |
+
performance deteriorates significantly in sorting algorithms and Kruskal’s algorithm. Naïve String
|
116 |
+
Matcher, on the other hand, improves with teacher forcing (see Appendix A, Figs. 7-9).
|
117 |
+
Augmenting the training data. To prevent our model from over-fitting to the statistics of the fixed
|
118 |
+
CLRS training dataset [5], we augmented the training data in three key ways, without breaking
|
119 |
+
the intended size distribution shift. Firstly, we used the on-line samplers in CLRS to generate new
|
120 |
+
training examples on the fly, rather than using a fixed dataset which is easier to overfit to. Secondly,
|
121 |
+
we trained on examples of mixed sizes, n ≤ 16, rather than only 16, which helps the model anticipate
|
122 |
+
for a diverse range of sizes, rather than overfitting to the specifics of size n = 16. Lastly, for graph
|
123 |
+
algorithms, we varied the connectivity probability p of the input graphs (generated by the Erdős-Rényi
|
124 |
+
model [38]); and for string matching algorithms, we varied the length of the pattern to be matched.
|
125 |
+
These both serve to expose the model to different trajectory lengths; for example, in many graph
|
126 |
+
algorithms, the amount of steps the algorithm should run for is related to the graph’s diameter, and
|
127 |
+
varying the connection probability in the graph generation allows for varying the expected diameter.
|
128 |
+
These changes considerably increase training data variability, compared to the original dataset in
|
129 |
+
Veličković et al. [5]. We provide a more detailed step-by-step overview of the data generation process
|
130 |
+
in Appendix A.
|
131 |
+
Soft hint propagation. When predicted hints are fed back as inputs during training, gradients
|
132 |
+
may or may not be allowed to flow through them. In previous work, only hints of the scalar type
|
133 |
+
allowed gradients through, as all categoricals were post-processed from logits into the ground-truth
|
134 |
+
format via argmax or thresholding before being fed back. Instead, in this work we use softmax
|
135 |
+
for categorical, mask_one and pointer types, and the logistic sigmoid for mask types. Without
|
136 |
+
these soft hints, performance in sorting algorithms degrades (similarly to the case of teacher forcing),
|
137 |
+
as well as in Naïve String Matcher (Appendix A, Figs. 7-9).
|
138 |
+
Static hint elimination. Eleven algorithms in CLRS3 specify a fixed ordering of the nodes, common
|
139 |
+
to every sample, via a node pointer hint that does not ever change along the trajectories. Prediction of
|
140 |
+
this hint is trivial (identity function), but poses a potential problem for OOD generalisation, since the
|
141 |
+
model can overfit to the fixed training values. We therefore turned this fixed hint into an input for
|
142 |
+
these 11 algorithms, eliminating the need for explicitly predicting it.
|
143 |
+
Improving training stability with encoder initialisation and gradient clipping. The scalar
|
144 |
+
hints have unbounded values, in principle, and are optimised using mean-squared error, hence their
|
145 |
+
gradients can quickly grow with increasing prediction error. Further, the predicted scalar hints then
|
146 |
+
get re-encoded at every step, which can rapidly amplify errors throughout the trajectory, leading to
|
147 |
+
exploding signals (and consequently gradients), even before any training takes place.
|
148 |
+
To rectify this issue, we use the Xavier initialisation [45], effectively reducing the initial weights for
|
149 |
+
scalar hints whose input dimensionality is just 1. However, we reverted to using the default LeCun
|
150 |
+
initialisation [46] elsewhere. This combination of initialisations proved important for the initial
|
151 |
+
learning stability of our model over long trajectories. Relatedly, in preliminary experiments, we saw
|
152 |
+
drastic improvements in learning stability, as well as significant increases in validation performance,
|
153 |
+
with gradient clipping [47], which we subsequently employed in all experiments.
|
154 |
+
3.2.2
|
155 |
+
|
156 |
+
Encoders and decoders
|
157 |
+
|
158 |
+
Randomised position scalar. Across all algorithms in the dataset, there exists a position scalar
|
159 |
+
input which uniquely indexes the nodes, with values linearly spaced between 0 and 1 along the node
|
160 |
+
index. To avoid overfitting to these linearly spaced values during training, we replaced them with
|
161 |
+
random values, uniformly sampled in [0, 1], sorted to match the initial order implied by the linearly
|
162 |
+
spaced values. The benefit of this change is notable in algorithms where it would be easy to overfit to
|
163 |
+
3
|
164 |
+
|
165 |
+
Binary Search, Minimum, Max Subarray [39], Matrix Chain Order, LCS Length, Optimal BST [40], Activity
|
166 |
+
Selector [41], Task Scheduling [42], Naïve String Matcher, Knuth-Morris-Pratt [43] and Jarvis’ March [44].
|
167 |
+
5
|
168 |
+
|
169 |
+
A Generalist Neural Algorithmic Learner
|
170 |
+
|
171 |
+
these positions, such as string matching. Namely, the model could learn to base all of its computations
|
172 |
+
on the assumption that it will always be finding a m-character pattern inside an n-character string,
|
173 |
+
even though at test time, m and n will increase fourfold.
|
174 |
+
Permutation decoders and the Sinkhorn operator. Sorting algorithms (Insertion Sort, Bubble
|
175 |
+
Sort, Heapsort [48] and Quicksort [49]) always output a permutation of the input nodes. In the CLRS
|
176 |
+
benchmark, this permutation is encoded as a pointer where each node points to its predecessor in
|
177 |
+
the sorted order (the first node points to itself); this is represented as a n × n matrix P where each
|
178 |
+
row is a one-hot vector, such that element (i, j) is 1 if node i points to node j. As with all types of
|
179 |
+
pointers, such permutation pointers can be predicted using a row-wise softmax on unconstrained
|
180 |
+
decoder outputs (logits), trained with cross entropy (as in Veličković et al. [5]). However, this does
|
181 |
+
not explicitly take advantage of the fact that the pointers encode a permutation, which the model
|
182 |
+
has to learn instead. Our early experiments showed that the model was often failing to predict valid
|
183 |
+
permutations OOD.
|
184 |
+
Accordingly, we enforce a permutation inductive bias in the output decoder of sorting algorithms, as
|
185 |
+
follows. First, we modify the output representation by rewiring the first node to point to the last one,
|
186 |
+
turning P into a permutation matrix, i.e., a matrix whose rows and columns are one-hot vectors. We
|
187 |
+
also augment the representation with a one-hot vector of size n that specifies the first node, so we do
|
188 |
+
not lose this information; this vector is treated like a regular mask_one feature. Second, we predict the
|
189 |
+
permutation matrix P from unconstrained decoder outputs Y by replacing the usual row-wise softmax
|
190 |
+
with the Sinkhorn operator S [32, 50–53]. S projects an arbitrary square matrix Y into a doubly
|
191 |
+
stochastic matrix S(Y) (a non-negative matrix whose rows and columns sum to 1), by exponentiating
|
192 |
+
and repeatedly normalizing rows and columns so they sum to 1. Specifically, S is defined by:
|
193 |
+
S 0 (Y) = exp(Y)
|
194 |
+
|
195 |
+
S l (Y) = Tc (Tr (S l−1 (Y)))
|
196 |
+
|
197 |
+
S(Y) = lim S l (Y),
|
198 |
+
l→∞
|
199 |
+
|
200 |
+
(2)
|
201 |
+
|
202 |
+
where exp acts element-wise, and Tr and Tc denote row and column normalisation respectively.
|
203 |
+
Although the Sinkhorn operator produces a doubly stochastic matrix rather than a permutation matrix,
|
204 |
+
we can obtain a permutation matrix by introducing a temperature parameter, τ > 0, and taking
|
205 |
+
P = limτ →0+ S(Y/τ ); as long as there are no ties in the elements of Y, P is guaranteed to be a
|
206 |
+
permutation matrix [52, Theorem 1].
|
207 |
+
In practice, we compute the Sinkhorn operator using a fixed number of iterations lmax . We use a
|
208 |
+
smaller number of iterations lmax = 10 for training, to limit vanishing and exploding gradients, and
|
209 |
+
lmax = 60 for evaluation. A fixed temperature τ = 0.1 was experimentally found to give a good
|
210 |
+
balance between speed of convergence and tie-breaking. We also encode the fact that no node points
|
211 |
+
to itself, that is, that all diagonal elements of P should be 0, by setting the diagonal elements of Y to
|
212 |
+
−∞. To avoid ties, we follow Mena et al. [53], injecting Gumbel noise to the elements of Y prior to
|
213 |
+
applying the Sinkhorn operator, during training only. Finally, we transform the predicted matrix P,
|
214 |
+
and mask_one pointing to the first element, into the original pointer representation used by CLRS.
|
215 |
+
3.2.3
|
216 |
+
|
217 |
+
Processor networks
|
218 |
+
|
219 |
+
Gating mechanisms. Many algorithms only require updating a few nodes at each time step, keeping
|
220 |
+
the rest unchanged. However, the MPNN we use (Equation 1) is biased towards the opposite: it
|
221 |
+
updates all hidden states in each step. Although it is theoretically possible for the network to keep the
|
222 |
+
states unchanged, learning to do so is not easy. With this in mind, and motivated by its effectiveness
|
223 |
+
in NDRs [54], we augment the network with an update gate, biased to be closed by default. We
|
224 |
+
found that the gate stabilizes learning on many of the tasks, and increases the mean performance
|
225 |
+
over all tasks on single-task training significantly. Surprisingly, however, we did not find gating to be
|
226 |
+
advantageous in the multi-task case.
|
227 |
+
To add gating to the MPNN model we produce a per-node gating vector from the same inputs that
|
228 |
+
process the embeddings in Equation 1:
|
229 |
+
|
230 |
+
|
231 |
+
(t)
|
232 |
+
(t)
|
233 |
+
(t)
|
234 |
+
gi = fg zi , mi
|
235 |
+
(3)
|
236 |
+
where fg : R2h × Rh → Rh is the gating function, for which we use a two-layer MLP, with
|
237 |
+
ReLU activation for the hidden layer and logistic sigmoid activation for the output. Importantly, the
|
238 |
+
final layer bias of fg is initialized to a value of −3, which biases the network for not updating its
|
239 |
+
6
|
240 |
+
|
241 |
+
A Generalist Neural Algorithmic Learner
|
242 |
+
|
243 |
+
Our model
|
244 |
+
Previous SOTA [5]
|
245 |
+
|
246 |
+
80
|
247 |
+
60
|
248 |
+
40
|
249 |
+
|
250 |
+
Quickselect
|
251 |
+
|
252 |
+
Heapsort
|
253 |
+
|
254 |
+
Knuth-Morris-Pratt
|
255 |
+
|
256 |
+
Strongly Conn. Comps.
|
257 |
+
|
258 |
+
DFS
|
259 |
+
|
260 |
+
Floyd-Warshall
|
261 |
+
|
262 |
+
Quicksort
|
263 |
+
|
264 |
+
Bubble Sort
|
265 |
+
|
266 |
+
Optimal BST
|
267 |
+
|
268 |
+
Find Max. Subarray
|
269 |
+
|
270 |
+
Insertion Sort
|
271 |
+
|
272 |
+
Binary Search
|
273 |
+
|
274 |
+
LCS Length
|
275 |
+
|
276 |
+
Naïve String Matcher
|
277 |
+
|
278 |
+
MST Prim
|
279 |
+
|
280 |
+
Topological Sort
|
281 |
+
|
282 |
+
Task Scheduling
|
283 |
+
|
284 |
+
MST Kruskal
|
285 |
+
|
286 |
+
Articulation Points
|
287 |
+
|
288 |
+
Jarvis' March
|
289 |
+
|
290 |
+
Matrix Chain Order
|
291 |
+
|
292 |
+
Bridges
|
293 |
+
|
294 |
+
Graham Scan
|
295 |
+
|
296 |
+
Dijkstra
|
297 |
+
|
298 |
+
Activity Selector
|
299 |
+
|
300 |
+
Bellman-Ford
|
301 |
+
|
302 |
+
DAG Shortest Paths
|
303 |
+
|
304 |
+
Segments Intersect
|
305 |
+
|
306 |
+
0
|
307 |
+
|
308 |
+
BFS
|
309 |
+
|
310 |
+
20
|
311 |
+
Minimum
|
312 |
+
|
313 |
+
Average score [%]
|
314 |
+
|
315 |
+
100
|
316 |
+
|
317 |
+
Figure 2: The OOD performance in single-task experiments before and after the improvements
|
318 |
+
presented in this paper, sorted in descending order of current performance. Error bars represent
|
319 |
+
standard error of the mean across seeds (3 seeds for previous SOTA experiments, 10 seeds for current).
|
320 |
+
The previous SOTA values are the best of MPNN, PGN and Memnet models (see Table 2).
|
321 |
+
b (t) , are computed as follows:
|
322 |
+
representations, unless necessary. The processed gated embeddings, h
|
323 |
+
i
|
324 |
+
b (t) = g(t)
|
325 |
+
h
|
326 |
+
i
|
327 |
+
i
|
328 |
+
and are used instead of
|
329 |
+
|
330 |
+
(t)
|
331 |
+
hi
|
332 |
+
|
333 |
+
(t)
|
334 |
+
|
335 |
+
(t)
|
336 |
+
|
337 |
+
hi + (1 − gi )
|
338 |
+
|
339 |
+
in the subsequent steps, replacing z
|
340 |
+
|
341 |
+
(t−1)
|
342 |
+
|
343 |
+
hi
|
344 |
+
(t)
|
345 |
+
|
346 |
+
(4)
|
347 |
+
|
348 |
+
in Eq. 1 by z
|
349 |
+
|
350 |
+
(t)
|
351 |
+
|
352 |
+
=
|
353 |
+
|
354 |
+
(t) b (t−1)
|
355 |
+
xi kh
|
356 |
+
.
|
357 |
+
i
|
358 |
+
|
359 |
+
Triplet reasoning. Several algorithms within CLRS-30 explicitly require edge-based reasoning—
|
360 |
+
where edges store values, and update them based on other edges’ values. An example of this is the
|
361 |
+
Floyd-Warshall algorithm [55], which computes all-pairs shortest paths in a weighted graph. The
|
362 |
+
update rule for dij , its estimate for the best distance from node i to j, is dij = mink dik + dkj , which
|
363 |
+
roughly says “the best way to get from i to j is to find the optimal mid-point k, travel from i to k, then
|
364 |
+
from k to j”. Similar rules are pervasive across many CLRS-30 algorithms, especially in dynamic
|
365 |
+
programming. Even though there are no node representations in the above update, all our processors
|
366 |
+
are centered on passing messages between node representations hi .
|
367 |
+
To rectify this situation, we augment our processor to perform message passing towards edges.
|
368 |
+
Referring again to the update for dij , we note that the edge representations are updated by choosing
|
369 |
+
an intermediate node, then aggregating over all possible choices. Accordingly, and as previously observed by Dudzik and Veličković [31], we introduce triplet reasoning: first, computing representations
|
370 |
+
over triplets of nodes, then reducing over one node to obtain edge latents:
|
371 |
+
tijk = ψt (hi , hj , hk , eij , eik , ekj , g)
|
372 |
+
hij = φt (max tijk )
|
373 |
+
(5)
|
374 |
+
k
|
375 |
+
|
376 |
+
Here, ψt is a triplet message function, mapping all relevant representations to a single vector for
|
377 |
+
each triplet of nodes, and φt is an edge readout function, which transforms the aggregated triplets
|
378 |
+
for each edge for later use. According to prior findings on the CLRS benchmark [5], we use the
|
379 |
+
max aggregation to obtain edge representations. The computed hij vectors can then be used in any
|
380 |
+
edge-based reasoning task, and empirically they are indeed significantly beneficial, even in tasks
|
381 |
+
where we did not initially anticipate such benefits. One example is Kruskal’s minimum spanning tree
|
382 |
+
algorithm [56], where we presume that access to triplet reasoning allowed the model to more easily
|
383 |
+
sort the edges by weight, as it selects how to augment the spanning forest at each step.
|
384 |
+
In order to keep the footprint of triplet embeddings as lightweight as possible, we compute only
|
385 |
+
8-dimensional features in ψt . φt then upscales the aggregated edge features back to 128 dimensions,
|
386 |
+
to make them compatible with the rest of the architecture. Our initial experimentation demonstrated
|
387 |
+
that the output dimensionality of ψt did not significantly affect downstream performance. Note that
|
388 |
+
computing triplet representations has been a useful approach in general GNN design [57]—however,
|
389 |
+
it has predominantly been studied in the context of GNNs over constant input features. Our study is
|
390 |
+
among the first to verify their utility over reasoning tasks with well-specified initial features.
|
391 |
+
3.3
|
392 |
+
|
393 |
+
Results
|
394 |
+
|
395 |
+
By incorporating the changes described in the previous sections we arrived at a single model type,
|
396 |
+
with a single set of hyper-parameters, that was trained to reach new state-of-the-art performance
|
397 |
+
7
|
398 |
+
|
399 |
+
A Generalist Neural Algorithmic Learner
|
400 |
+
|
401 |
+
Table 1: Single-task OOD micro-F1 score of previous SOTA Memnet, MPNN and PGN [5] and our
|
402 |
+
best model Triplet-GMPNN with all our improvements, after 10,000 training steps.
|
403 |
+
Alg. Type
|
404 |
+
|
405 |
+
Memnet [5]
|
406 |
+
|
407 |
+
MPNN [5]
|
408 |
+
|
409 |
+
PGN [5]
|
410 |
+
|
411 |
+
Triplet-GMPNN (ours)
|
412 |
+
|
413 |
+
Div. & C.
|
414 |
+
DP
|
415 |
+
Geometry
|
416 |
+
Graphs
|
417 |
+
Greedy
|
418 |
+
Search
|
419 |
+
Sorting
|
420 |
+
Strings
|
421 |
+
|
422 |
+
13.05% ± 0.14
|
423 |
+
67.94% ± 8.20
|
424 |
+
45.14% ± 11.95
|
425 |
+
24.12% ± 5.30
|
426 |
+
53.42% ± 20.82
|
427 |
+
34.35% ± 21.67
|
428 |
+
71.53% ± 1.41
|
429 |
+
1.51% ± 0.46
|
430 |
+
|
431 |
+
20.30% ± 0.85
|
432 |
+
65.10% ± 6.44
|
433 |
+
73.11% ± 17.19
|
434 |
+
62.79% ± 8.75
|
435 |
+
82.39% ± 3.01
|
436 |
+
41.20% ± 19.87
|
437 |
+
11.83% ± 2.78
|
438 |
+
3.21% ± 0.94
|
439 |
+
|
440 |
+
65.23% ± 4.44
|
441 |
+
70.58% ± 6.48
|
442 |
+
61.19% ± 7.01
|
443 |
+
60.25% ± 8.42
|
444 |
+
75.84% ± 6.59
|
445 |
+
56.11% ± 21.56
|
446 |
+
15.45% ± 8.46
|
447 |
+
2.04% ± 0.20
|
448 |
+
|
449 |
+
76.36% ± 1.34
|
450 |
+
81.99% ± 4.98
|
451 |
+
94.09% ± 2.30
|
452 |
+
81.41% ± 6.21
|
453 |
+
91.21% ± 2.95
|
454 |
+
58.61% ± 24.34
|
455 |
+
60.37% ± 12.16
|
456 |
+
49.09% ± 23.49
|
457 |
+
|
458 |
+
38.88%
|
459 |
+
|
460 |
+
44.99%
|
461 |
+
|
462 |
+
50.84%
|
463 |
+
|
464 |
+
74.14%
|
465 |
+
|
466 |
+
0/30
|
467 |
+
3/30
|
468 |
+
10/30
|
469 |
+
|
470 |
+
6/30
|
471 |
+
9/30
|
472 |
+
14/30
|
473 |
+
|
474 |
+
3/30
|
475 |
+
7/30
|
476 |
+
15/30
|
477 |
+
|
478 |
+
11/30
|
479 |
+
17/30
|
480 |
+
24/30
|
481 |
+
|
482 |
+
Overall avg.
|
483 |
+
> 90%
|
484 |
+
> 80%
|
485 |
+
> 60%
|
486 |
+
|
487 |
+
on CLRS-30 [5]. Tables 1 and 2 show the micro-F1 scores of our model, which we refer to as
|
488 |
+
Triplet-GMPNN (an MPNN with gating and triplet edge processing), over the original CLRS-30 test
|
489 |
+
set (computed identically to Veličković et al. [5], but with 10 repetitions instead of 3). Our baselines
|
490 |
+
include the Memnet [58], MPNN [35] and PGN [59] models, taken directly from Veličković et al. [5].
|
491 |
+
Figure 2 displays the comparison between the improved model and the best model from Veličković
|
492 |
+
et al. [5]. Our improvements lead to an overall average performance that is more than 20% higher
|
493 |
+
(in absolute terms) compared to the next best model (see Table 1), and to a significant performance
|
494 |
+
improvement in all but one algorithm family, compared to every other model. Further, our stabilising
|
495 |
+
changes (such as gradient clipping) have empirically reduced the scale of our model’s gradient
|
496 |
+
updates across the 30 tasks, preparing us better for the numerical issues of the multi-task regime. We
|
497 |
+
finally also note that though we do not show it in Tables 1 & 2, applying the same improvements to
|
498 |
+
the PGN processor, leads to an increase in overall performance from 50.84% (Table 1) to 69.31%.
|
499 |
+
There are two notable examples of algorithm families with significant OOD performance improvement.
|
500 |
+
The first are geometric algorithms (Segments Intersect, Graham Scan [60] and Jarvis’ March), now
|
501 |
+
solved at approximately 94% OOD, compared to the previous best of about 73%; the second being
|
502 |
+
string algorithms (Knuth-Morris-Pratt and Naïve String Matcher) for which our model now exceeds
|
503 |
+
49% compared to the previous best of approximately 3%.
|
504 |
+
The significant overall performance boost is reflected in the increased number of algorithms we can
|
505 |
+
now solve at over 60%, 80% & 90% OOD performance, compared to previous SOTA [5]. Specifically,
|
506 |
+
we now exceed 60% accuracy in 24 algorithms (15 algorithms previously), 80% for 17 algorithms (9
|
507 |
+
previously) and 90% for 11 algorithms (6 previously).
|
benchmarks/CLRS/env/baselines.py
ADDED
@@ -0,0 +1,794 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""JAX implementation of CLRS baseline models."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
import os
|
20 |
+
import pickle
|
21 |
+
from typing import Dict, List, Optional, Tuple, Union
|
22 |
+
|
23 |
+
import chex
|
24 |
+
|
25 |
+
from clrs._src import decoders
|
26 |
+
from clrs._src import losses
|
27 |
+
from clrs._src import model
|
28 |
+
from clrs._src import nets
|
29 |
+
from clrs._src import probing
|
30 |
+
from clrs._src import processors
|
31 |
+
from clrs._src import samplers
|
32 |
+
from clrs._src import specs
|
33 |
+
|
34 |
+
import haiku as hk
|
35 |
+
import jax
|
36 |
+
import jax.numpy as jnp
|
37 |
+
import numpy as np
|
38 |
+
import optax
|
39 |
+
|
40 |
+
|
41 |
+
_Array = chex.Array
|
42 |
+
_DataPoint = probing.DataPoint
|
43 |
+
_Features = samplers.Features
|
44 |
+
_FeaturesChunked = samplers.FeaturesChunked
|
45 |
+
_Feedback = samplers.Feedback
|
46 |
+
_Location = specs.Location
|
47 |
+
_Seed = jnp.integer
|
48 |
+
_Spec = specs.Spec
|
49 |
+
_Stage = specs.Stage
|
50 |
+
_Trajectory = samplers.Trajectory
|
51 |
+
_Type = specs.Type
|
52 |
+
_OutputClass = specs.OutputClass
|
53 |
+
|
54 |
+
# pytype: disable=signature-mismatch
|
55 |
+
|
56 |
+
|
57 |
+
def _maybe_pick_first_pmapped(tree):
|
58 |
+
if jax.local_device_count() == 1:
|
59 |
+
return tree
|
60 |
+
return jax.tree_util.tree_map(lambda x: x[0], tree)
|
61 |
+
|
62 |
+
|
63 |
+
@jax.jit
|
64 |
+
def _restack_from_pmap(tree):
|
65 |
+
"""Stack the results of a pmapped computation across the first two axes."""
|
66 |
+
restack_array = lambda x: jnp.reshape(x, (-1,) + x.shape[2:])
|
67 |
+
return jax.tree_util.tree_map(restack_array, tree)
|
68 |
+
|
69 |
+
|
70 |
+
def _maybe_restack_from_pmap(tree):
|
71 |
+
if jax.local_device_count() == 1:
|
72 |
+
return tree
|
73 |
+
return _restack_from_pmap(tree)
|
74 |
+
|
75 |
+
|
76 |
+
@functools.partial(jax.jit, static_argnums=[1, 2])
|
77 |
+
def _pmap_reshape(x, n_devices, split_axis=0):
|
78 |
+
"""Splits a pytree over n_devices on axis split_axis for pmapping."""
|
79 |
+
def _reshape(arr):
|
80 |
+
new_shape = (arr.shape[:split_axis] +
|
81 |
+
(n_devices, arr.shape[split_axis] // n_devices) +
|
82 |
+
arr.shape[split_axis + 1:])
|
83 |
+
return jnp.moveaxis(jnp.reshape(arr, new_shape), split_axis, 0)
|
84 |
+
return jax.tree_util.tree_map(_reshape, x)
|
85 |
+
|
86 |
+
|
87 |
+
def _maybe_pmap_reshape(x, split_axis=0):
|
88 |
+
n_devices = jax.local_device_count()
|
89 |
+
if n_devices == 1:
|
90 |
+
return x
|
91 |
+
return _pmap_reshape(x, n_devices, split_axis)
|
92 |
+
|
93 |
+
|
94 |
+
@functools.partial(jax.jit, static_argnums=1)
|
95 |
+
def _pmap_data(data: Union[_Feedback, _Features], n_devices: int):
|
96 |
+
"""Replicate/split feedback or features for pmapping."""
|
97 |
+
if isinstance(data, _Feedback):
|
98 |
+
features = data.features
|
99 |
+
else:
|
100 |
+
features = data
|
101 |
+
pmap_data = features._replace(
|
102 |
+
inputs=_pmap_reshape(features.inputs, n_devices),
|
103 |
+
hints=_pmap_reshape(features.hints, n_devices, split_axis=1),
|
104 |
+
lengths=_pmap_reshape(features.lengths, n_devices),
|
105 |
+
)
|
106 |
+
if isinstance(data, _Feedback):
|
107 |
+
pmap_data = data._replace(
|
108 |
+
features=pmap_data,
|
109 |
+
outputs=_pmap_reshape(data.outputs, n_devices)
|
110 |
+
)
|
111 |
+
return pmap_data
|
112 |
+
|
113 |
+
|
114 |
+
def _maybe_pmap_data(data: Union[_Feedback, _Features]):
|
115 |
+
n_devices = jax.local_device_count()
|
116 |
+
if n_devices == 1:
|
117 |
+
return data
|
118 |
+
return _pmap_data(data, n_devices)
|
119 |
+
|
120 |
+
|
121 |
+
def _maybe_put_replicated(tree):
|
122 |
+
if jax.local_device_count() == 1:
|
123 |
+
return jax.device_put(tree)
|
124 |
+
else:
|
125 |
+
return jax.device_put_replicated(tree, jax.local_devices())
|
126 |
+
|
127 |
+
|
128 |
+
def _maybe_pmap_rng_key(rng_key: _Array):
|
129 |
+
n_devices = jax.local_device_count()
|
130 |
+
if n_devices == 1:
|
131 |
+
return rng_key
|
132 |
+
pmap_rng_keys = jax.random.split(rng_key, n_devices)
|
133 |
+
return jax.device_put_sharded(list(pmap_rng_keys), jax.local_devices())
|
134 |
+
|
135 |
+
|
136 |
+
class BaselineModel(model.Model):
|
137 |
+
"""Model implementation with selectable message passing algorithm."""
|
138 |
+
|
139 |
+
def __init__(
|
140 |
+
self,
|
141 |
+
spec: Union[_Spec, List[_Spec]],
|
142 |
+
dummy_trajectory: Union[List[_Feedback], _Feedback],
|
143 |
+
processor_factory: processors.ProcessorFactory,
|
144 |
+
hidden_dim: int = 32,
|
145 |
+
encode_hints: bool = False,
|
146 |
+
decode_hints: bool = True,
|
147 |
+
encoder_init: str = 'default',
|
148 |
+
use_lstm: bool = False,
|
149 |
+
learning_rate: float = 0.005,
|
150 |
+
grad_clip_max_norm: float = 0.0,
|
151 |
+
checkpoint_path: str = '/tmp/clrs3',
|
152 |
+
freeze_processor: bool = False,
|
153 |
+
dropout_prob: float = 0.0,
|
154 |
+
hint_teacher_forcing: float = 0.0,
|
155 |
+
hint_repred_mode: str = 'soft',
|
156 |
+
name: str = 'base_model',
|
157 |
+
nb_msg_passing_steps: int = 1,
|
158 |
+
):
|
159 |
+
"""Constructor for BaselineModel.
|
160 |
+
|
161 |
+
The model consists of encoders, processor and decoders. It can train
|
162 |
+
and evaluate either a single algorithm or a set of algorithms; in the
|
163 |
+
latter case, a single processor is shared among all the algorithms, while
|
164 |
+
the encoders and decoders are separate for each algorithm.
|
165 |
+
|
166 |
+
Args:
|
167 |
+
spec: Either a single spec for one algorithm, or a list of specs for
|
168 |
+
multiple algorithms to be trained and evaluated.
|
169 |
+
dummy_trajectory: Either a single feedback batch, in the single-algorithm
|
170 |
+
case, or a list of feedback batches, in the multi-algorithm case, that
|
171 |
+
comply with the `spec` (or list of specs), to initialize network size.
|
172 |
+
processor_factory: A callable that takes an `out_size` parameter
|
173 |
+
and returns a processor (see `processors.py`).
|
174 |
+
hidden_dim: Size of the hidden state of the model, i.e., size of the
|
175 |
+
message-passing vectors.
|
176 |
+
encode_hints: Whether to provide hints as model inputs.
|
177 |
+
decode_hints: Whether to provide hints as model outputs.
|
178 |
+
encoder_init: The initialiser type to use for the encoders.
|
179 |
+
use_lstm: Whether to insert an LSTM after message passing.
|
180 |
+
learning_rate: Learning rate for training.
|
181 |
+
grad_clip_max_norm: if greater than 0, the maximum norm of the gradients.
|
182 |
+
checkpoint_path: Path for loading/saving checkpoints.
|
183 |
+
freeze_processor: If True, the processor weights will be frozen and
|
184 |
+
only encoders and decoders (and, if used, the lstm) will be trained.
|
185 |
+
dropout_prob: Dropout rate in the message-passing stage.
|
186 |
+
hint_teacher_forcing: Probability of using ground-truth hints instead
|
187 |
+
of predicted hints as inputs during training (only relevant if
|
188 |
+
`encode_hints`=True)
|
189 |
+
hint_repred_mode: How to process predicted hints when fed back as inputs.
|
190 |
+
Only meaningful when `encode_hints` and `decode_hints` are True.
|
191 |
+
Options are:
|
192 |
+
- 'soft', where we use softmaxes for categoricals, pointers
|
193 |
+
and mask_one, and sigmoids for masks. This will allow gradients
|
194 |
+
to flow through hints during training.
|
195 |
+
- 'hard', where we use argmax instead of softmax, and hard
|
196 |
+
thresholding of masks. No gradients will go through the hints
|
197 |
+
during training; even for scalar hints, which don't have any
|
198 |
+
kind of post-processing, gradients will be stopped.
|
199 |
+
- 'hard_on_eval', which is soft for training and hard for evaluation.
|
200 |
+
name: Model name.
|
201 |
+
nb_msg_passing_steps: Number of message passing steps per hint.
|
202 |
+
|
203 |
+
Raises:
|
204 |
+
ValueError: if `encode_hints=True` and `decode_hints=False`.
|
205 |
+
"""
|
206 |
+
super(BaselineModel, self).__init__(spec=spec)
|
207 |
+
|
208 |
+
if encode_hints and not decode_hints:
|
209 |
+
raise ValueError('`encode_hints=True`, `decode_hints=False` is invalid.')
|
210 |
+
|
211 |
+
assert hint_repred_mode in ['soft', 'hard', 'hard_on_eval']
|
212 |
+
|
213 |
+
self.decode_hints = decode_hints
|
214 |
+
self.checkpoint_path = checkpoint_path
|
215 |
+
self.name = name
|
216 |
+
self._freeze_processor = freeze_processor
|
217 |
+
if grad_clip_max_norm != 0.0:
|
218 |
+
optax_chain = [optax.clip_by_global_norm(grad_clip_max_norm),
|
219 |
+
optax.scale_by_adam(),
|
220 |
+
optax.scale(-learning_rate)]
|
221 |
+
self.opt = optax.chain(*optax_chain)
|
222 |
+
else:
|
223 |
+
self.opt = optax.adam(learning_rate)
|
224 |
+
|
225 |
+
self.nb_msg_passing_steps = nb_msg_passing_steps
|
226 |
+
|
227 |
+
self.nb_dims = []
|
228 |
+
if isinstance(dummy_trajectory, _Feedback):
|
229 |
+
assert len(self._spec) == 1
|
230 |
+
dummy_trajectory = [dummy_trajectory]
|
231 |
+
for traj in dummy_trajectory:
|
232 |
+
nb_dims = {}
|
233 |
+
for inp in traj.features.inputs:
|
234 |
+
nb_dims[inp.name] = inp.data.shape[-1]
|
235 |
+
for hint in traj.features.hints:
|
236 |
+
nb_dims[hint.name] = hint.data.shape[-1]
|
237 |
+
for outp in traj.outputs:
|
238 |
+
nb_dims[outp.name] = outp.data.shape[-1]
|
239 |
+
self.nb_dims.append(nb_dims)
|
240 |
+
|
241 |
+
self._create_net_fns(hidden_dim, encode_hints, processor_factory, use_lstm,
|
242 |
+
encoder_init, dropout_prob, hint_teacher_forcing,
|
243 |
+
hint_repred_mode)
|
244 |
+
self._device_params = None
|
245 |
+
self._device_opt_state = None
|
246 |
+
self.opt_state_skeleton = None
|
247 |
+
|
248 |
+
def _create_net_fns(self, hidden_dim, encode_hints, processor_factory,
|
249 |
+
use_lstm, encoder_init, dropout_prob,
|
250 |
+
hint_teacher_forcing, hint_repred_mode):
|
251 |
+
def _use_net(*args, **kwargs):
|
252 |
+
return nets.Net(self._spec, hidden_dim, encode_hints, self.decode_hints,
|
253 |
+
processor_factory, use_lstm, encoder_init,
|
254 |
+
dropout_prob, hint_teacher_forcing,
|
255 |
+
hint_repred_mode,
|
256 |
+
self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs)
|
257 |
+
|
258 |
+
self.net_fn = hk.transform(_use_net)
|
259 |
+
pmap_args = dict(axis_name='batch', devices=jax.local_devices())
|
260 |
+
n_devices = jax.local_device_count()
|
261 |
+
func, static_arg, extra_args = (
|
262 |
+
(jax.jit, 'static_argnums', {}) if n_devices == 1 else
|
263 |
+
(jax.pmap, 'static_broadcasted_argnums', pmap_args))
|
264 |
+
pmean = functools.partial(jax.lax.pmean, axis_name='batch')
|
265 |
+
self._maybe_pmean = pmean if n_devices > 1 else lambda x: x
|
266 |
+
extra_args[static_arg] = 3
|
267 |
+
self.jitted_grad = func(self._compute_grad, **extra_args)
|
268 |
+
extra_args[static_arg] = 4
|
269 |
+
self.jitted_feedback = func(self._feedback, donate_argnums=[0, 3],
|
270 |
+
**extra_args)
|
271 |
+
extra_args[static_arg] = [3, 4, 5]
|
272 |
+
self.jitted_predict = func(self._predict, **extra_args)
|
273 |
+
extra_args[static_arg] = [3, 4]
|
274 |
+
self.jitted_accum_opt_update = func(accum_opt_update, donate_argnums=[0, 2],
|
275 |
+
**extra_args)
|
276 |
+
|
277 |
+
def init(self, features: Union[_Features, List[_Features]], seed: _Seed):
|
278 |
+
if not isinstance(features, list):
|
279 |
+
assert len(self._spec) == 1
|
280 |
+
features = [features]
|
281 |
+
self.params = self.net_fn.init(jax.random.PRNGKey(seed), features, True, # pytype: disable=wrong-arg-types # jax-ndarray
|
282 |
+
algorithm_index=-1,
|
283 |
+
return_hints=False,
|
284 |
+
return_all_outputs=False)
|
285 |
+
self.opt_state = self.opt.init(self.params)
|
286 |
+
# We will use the optimizer state skeleton for traversal when we
|
287 |
+
# want to avoid updating the state of params of untrained algorithms.
|
288 |
+
self.opt_state_skeleton = self.opt.init(jnp.zeros(1))
|
289 |
+
|
290 |
+
@property
|
291 |
+
def params(self):
|
292 |
+
if self._device_params is None:
|
293 |
+
return None
|
294 |
+
return jax.device_get(_maybe_pick_first_pmapped(self._device_params))
|
295 |
+
|
296 |
+
@params.setter
|
297 |
+
def params(self, params):
|
298 |
+
self._device_params = _maybe_put_replicated(params)
|
299 |
+
|
300 |
+
@property
|
301 |
+
def opt_state(self):
|
302 |
+
if self._device_opt_state is None:
|
303 |
+
return None
|
304 |
+
return jax.device_get(_maybe_pick_first_pmapped(self._device_opt_state))
|
305 |
+
|
306 |
+
@opt_state.setter
|
307 |
+
def opt_state(self, opt_state):
|
308 |
+
self._device_opt_state = _maybe_put_replicated(opt_state)
|
309 |
+
|
310 |
+
def _compute_grad(self, params, rng_key, feedback, algorithm_index):
|
311 |
+
lss, grads = jax.value_and_grad(self._loss)(
|
312 |
+
params, rng_key, feedback, algorithm_index)
|
313 |
+
return self._maybe_pmean(lss), self._maybe_pmean(grads)
|
314 |
+
|
315 |
+
def _feedback(self, params, rng_key, feedback, opt_state, algorithm_index):
|
316 |
+
lss, grads = jax.value_and_grad(self._loss)(
|
317 |
+
params, rng_key, feedback, algorithm_index)
|
318 |
+
grads = self._maybe_pmean(grads)
|
319 |
+
params, opt_state = self._update_params(params, grads, opt_state,
|
320 |
+
algorithm_index)
|
321 |
+
lss = self._maybe_pmean(lss)
|
322 |
+
return lss, params, opt_state
|
323 |
+
|
324 |
+
def _predict(self, params, rng_key: hk.PRNGSequence, features: _Features,
|
325 |
+
algorithm_index: int, return_hints: bool,
|
326 |
+
return_all_outputs: bool):
|
327 |
+
outs, hint_preds = self.net_fn.apply(
|
328 |
+
params, rng_key, [features],
|
329 |
+
repred=True, algorithm_index=algorithm_index,
|
330 |
+
return_hints=return_hints,
|
331 |
+
return_all_outputs=return_all_outputs)
|
332 |
+
outs = decoders.postprocess(self._spec[algorithm_index],
|
333 |
+
outs,
|
334 |
+
sinkhorn_temperature=0.1,
|
335 |
+
sinkhorn_steps=50,
|
336 |
+
hard=True,
|
337 |
+
)
|
338 |
+
return outs, hint_preds
|
339 |
+
|
340 |
+
def compute_grad(
|
341 |
+
self,
|
342 |
+
rng_key: hk.PRNGSequence,
|
343 |
+
feedback: _Feedback,
|
344 |
+
algorithm_index: Optional[int] = None,
|
345 |
+
) -> Tuple[float, _Array]:
|
346 |
+
"""Compute gradients."""
|
347 |
+
|
348 |
+
if algorithm_index is None:
|
349 |
+
assert len(self._spec) == 1
|
350 |
+
algorithm_index = 0
|
351 |
+
assert algorithm_index >= 0
|
352 |
+
|
353 |
+
# Calculate gradients.
|
354 |
+
rng_keys = _maybe_pmap_rng_key(rng_key) # pytype: disable=wrong-arg-types # numpy-scalars
|
355 |
+
feedback = _maybe_pmap_data(feedback)
|
356 |
+
loss, grads = self.jitted_grad(
|
357 |
+
self._device_params, rng_keys, feedback, algorithm_index)
|
358 |
+
loss = _maybe_pick_first_pmapped(loss)
|
359 |
+
grads = _maybe_pick_first_pmapped(grads)
|
360 |
+
|
361 |
+
return loss, grads
|
362 |
+
|
363 |
+
def feedback(self, rng_key: hk.PRNGSequence, feedback: _Feedback,
|
364 |
+
algorithm_index=None) -> float:
|
365 |
+
if algorithm_index is None:
|
366 |
+
assert len(self._spec) == 1
|
367 |
+
algorithm_index = 0
|
368 |
+
# Calculate and apply gradients.
|
369 |
+
rng_keys = _maybe_pmap_rng_key(rng_key) # pytype: disable=wrong-arg-types # numpy-scalars
|
370 |
+
feedback = _maybe_pmap_data(feedback)
|
371 |
+
loss, self._device_params, self._device_opt_state = self.jitted_feedback(
|
372 |
+
self._device_params, rng_keys, feedback,
|
373 |
+
self._device_opt_state, algorithm_index)
|
374 |
+
loss = _maybe_pick_first_pmapped(loss)
|
375 |
+
return loss
|
376 |
+
|
377 |
+
def predict(self, rng_key: hk.PRNGSequence, features: _Features,
|
378 |
+
algorithm_index: Optional[int] = None,
|
379 |
+
return_hints: bool = False,
|
380 |
+
return_all_outputs: bool = False):
|
381 |
+
"""Model inference step."""
|
382 |
+
if algorithm_index is None:
|
383 |
+
assert len(self._spec) == 1
|
384 |
+
algorithm_index = 0
|
385 |
+
|
386 |
+
rng_keys = _maybe_pmap_rng_key(rng_key) # pytype: disable=wrong-arg-types # numpy-scalars
|
387 |
+
features = _maybe_pmap_data(features)
|
388 |
+
return _maybe_restack_from_pmap(
|
389 |
+
self.jitted_predict(
|
390 |
+
self._device_params, rng_keys, features,
|
391 |
+
algorithm_index,
|
392 |
+
return_hints,
|
393 |
+
return_all_outputs))
|
394 |
+
|
395 |
+
def _loss(self, params, rng_key, feedback, algorithm_index):
|
396 |
+
"""Calculates model loss f(feedback; params)."""
|
397 |
+
output_preds, hint_preds = self.net_fn.apply(
|
398 |
+
params, rng_key, [feedback.features],
|
399 |
+
repred=False,
|
400 |
+
algorithm_index=algorithm_index,
|
401 |
+
return_hints=True,
|
402 |
+
return_all_outputs=False)
|
403 |
+
|
404 |
+
nb_nodes = _nb_nodes(feedback, is_chunked=False)
|
405 |
+
lengths = feedback.features.lengths
|
406 |
+
total_loss = 0.0
|
407 |
+
|
408 |
+
# Calculate output loss.
|
409 |
+
for truth in feedback.outputs:
|
410 |
+
total_loss += losses.output_loss(
|
411 |
+
truth=truth,
|
412 |
+
pred=output_preds[truth.name],
|
413 |
+
nb_nodes=nb_nodes,
|
414 |
+
)
|
415 |
+
|
416 |
+
# Optionally accumulate hint losses.
|
417 |
+
if self.decode_hints:
|
418 |
+
for truth in feedback.features.hints:
|
419 |
+
total_loss += losses.hint_loss(
|
420 |
+
truth=truth,
|
421 |
+
preds=[x[truth.name] for x in hint_preds],
|
422 |
+
lengths=lengths,
|
423 |
+
nb_nodes=nb_nodes,
|
424 |
+
)
|
425 |
+
|
426 |
+
return total_loss
|
427 |
+
|
428 |
+
def _update_params(self, params, grads, opt_state, algorithm_index):
|
429 |
+
updates, opt_state = filter_null_grads(
|
430 |
+
grads, self.opt, opt_state, self.opt_state_skeleton, algorithm_index)
|
431 |
+
if self._freeze_processor:
|
432 |
+
params_subset = _filter_out_processor(params)
|
433 |
+
updates_subset = _filter_out_processor(updates)
|
434 |
+
assert len(params) > len(params_subset)
|
435 |
+
assert params_subset
|
436 |
+
new_params = optax.apply_updates(params_subset, updates_subset)
|
437 |
+
new_params = hk.data_structures.merge(params, new_params)
|
438 |
+
else:
|
439 |
+
new_params = optax.apply_updates(params, updates)
|
440 |
+
|
441 |
+
return new_params, opt_state
|
442 |
+
|
443 |
+
def update_model_params_accum(self, grads) -> None:
|
444 |
+
grads = _maybe_put_replicated(grads)
|
445 |
+
self._device_params, self._device_opt_state = self.jitted_accum_opt_update(
|
446 |
+
self._device_params, grads, self._device_opt_state, self.opt,
|
447 |
+
self._freeze_processor)
|
448 |
+
|
449 |
+
def verbose_loss(self, feedback: _Feedback, extra_info) -> Dict[str, _Array]:
|
450 |
+
"""Gets verbose loss information."""
|
451 |
+
hint_preds = extra_info
|
452 |
+
|
453 |
+
nb_nodes = _nb_nodes(feedback, is_chunked=False)
|
454 |
+
lengths = feedback.features.lengths
|
455 |
+
losses_ = {}
|
456 |
+
|
457 |
+
# Optionally accumulate hint losses.
|
458 |
+
if self.decode_hints:
|
459 |
+
for truth in feedback.features.hints:
|
460 |
+
losses_.update(
|
461 |
+
losses.hint_loss(
|
462 |
+
truth=truth,
|
463 |
+
preds=[x[truth.name] for x in hint_preds],
|
464 |
+
lengths=lengths,
|
465 |
+
nb_nodes=nb_nodes,
|
466 |
+
verbose=True,
|
467 |
+
))
|
468 |
+
|
469 |
+
return losses_
|
470 |
+
|
471 |
+
def restore_model(self, file_name: str, only_load_processor: bool = False):
|
472 |
+
"""Restore model from `file_name`."""
|
473 |
+
path = os.path.join(self.checkpoint_path, file_name)
|
474 |
+
with open(path, 'rb') as f:
|
475 |
+
restored_state = pickle.load(f)
|
476 |
+
if only_load_processor:
|
477 |
+
restored_params = _filter_in_processor(restored_state['params'])
|
478 |
+
else:
|
479 |
+
restored_params = restored_state['params']
|
480 |
+
self.params = hk.data_structures.merge(self.params, restored_params)
|
481 |
+
self.opt_state = restored_state['opt_state']
|
482 |
+
|
483 |
+
def save_model(self, file_name: str):
|
484 |
+
"""Save model (processor weights only) to `file_name`."""
|
485 |
+
os.makedirs(self.checkpoint_path, exist_ok=True)
|
486 |
+
to_save = {'params': self.params, 'opt_state': self.opt_state}
|
487 |
+
path = os.path.join(self.checkpoint_path, file_name)
|
488 |
+
with open(path, 'wb') as f:
|
489 |
+
pickle.dump(to_save, f)
|
490 |
+
|
491 |
+
|
492 |
+
class BaselineModelChunked(BaselineModel):
|
493 |
+
"""Model that processes time-chunked data.
|
494 |
+
|
495 |
+
Unlike `BaselineModel`, which processes full samples, `BaselineModelChunked`
|
496 |
+
processes fixed-timelength chunks of data. Each tensor of inputs and hints
|
497 |
+
has dimensions chunk_length x batch_size x ... The beginning of a new
|
498 |
+
sample withing the chunk is signalled by a tensor called `is_first` of
|
499 |
+
dimensions chunk_length x batch_size.
|
500 |
+
|
501 |
+
The chunked model is intended for training. For validation and test, use
|
502 |
+
`BaselineModel`.
|
503 |
+
"""
|
504 |
+
|
505 |
+
mp_states: List[List[nets.MessagePassingStateChunked]]
|
506 |
+
init_mp_states: List[List[nets.MessagePassingStateChunked]]
|
507 |
+
|
508 |
+
def _create_net_fns(self, hidden_dim, encode_hints, processor_factory,
|
509 |
+
use_lstm, encoder_init, dropout_prob,
|
510 |
+
hint_teacher_forcing, hint_repred_mode):
|
511 |
+
def _use_net(*args, **kwargs):
|
512 |
+
return nets.NetChunked(
|
513 |
+
self._spec, hidden_dim, encode_hints, self.decode_hints,
|
514 |
+
processor_factory, use_lstm, encoder_init, dropout_prob,
|
515 |
+
hint_teacher_forcing, hint_repred_mode,
|
516 |
+
self.nb_dims, self.nb_msg_passing_steps)(*args, **kwargs)
|
517 |
+
|
518 |
+
self.net_fn = hk.transform(_use_net)
|
519 |
+
pmap_args = dict(axis_name='batch', devices=jax.local_devices())
|
520 |
+
n_devices = jax.local_device_count()
|
521 |
+
func, static_arg, extra_args = (
|
522 |
+
(jax.jit, 'static_argnums', {}) if n_devices == 1 else
|
523 |
+
(jax.pmap, 'static_broadcasted_argnums', pmap_args))
|
524 |
+
pmean = functools.partial(jax.lax.pmean, axis_name='batch')
|
525 |
+
self._maybe_pmean = pmean if n_devices > 1 else lambda x: x
|
526 |
+
extra_args[static_arg] = 4
|
527 |
+
self.jitted_grad = func(self._compute_grad, **extra_args)
|
528 |
+
extra_args[static_arg] = 5
|
529 |
+
self.jitted_feedback = func(self._feedback, donate_argnums=[0, 4],
|
530 |
+
**extra_args)
|
531 |
+
extra_args[static_arg] = [3, 4]
|
532 |
+
self.jitted_accum_opt_update = func(accum_opt_update, donate_argnums=[0, 2],
|
533 |
+
**extra_args)
|
534 |
+
|
535 |
+
def _init_mp_state(self, features_list: List[List[_FeaturesChunked]],
|
536 |
+
rng_key: _Array):
|
537 |
+
def _empty_mp_state():
|
538 |
+
return nets.MessagePassingStateChunked( # pytype: disable=wrong-arg-types # numpy-scalars
|
539 |
+
inputs=None, hints=None, is_first=None,
|
540 |
+
hint_preds=None, hiddens=None, lstm_state=None)
|
541 |
+
empty_mp_states = [[_empty_mp_state() for _ in f] for f in features_list]
|
542 |
+
dummy_params = [self.net_fn.init(rng_key, f, e, False,
|
543 |
+
init_mp_state=True, algorithm_index=-1)
|
544 |
+
for (f, e) in zip(features_list, empty_mp_states)]
|
545 |
+
mp_states = [
|
546 |
+
self.net_fn.apply(d, rng_key, f, e, False,
|
547 |
+
init_mp_state=True, algorithm_index=-1)[1]
|
548 |
+
for (d, f, e) in zip(dummy_params, features_list, empty_mp_states)]
|
549 |
+
return mp_states
|
550 |
+
|
551 |
+
def init(self,
|
552 |
+
features: List[List[_FeaturesChunked]],
|
553 |
+
seed: _Seed):
|
554 |
+
self.mp_states = self._init_mp_state(features,
|
555 |
+
jax.random.PRNGKey(seed)) # pytype: disable=wrong-arg-types # jax-ndarray
|
556 |
+
self.init_mp_states = [list(x) for x in self.mp_states]
|
557 |
+
self.params = self.net_fn.init(
|
558 |
+
jax.random.PRNGKey(seed), features[0], self.mp_states[0], # pytype: disable=wrong-arg-types # jax-ndarray
|
559 |
+
True, init_mp_state=False, algorithm_index=-1)
|
560 |
+
self.opt_state = self.opt.init(self.params)
|
561 |
+
# We will use the optimizer state skeleton for traversal when we
|
562 |
+
# want to avoid updating the state of params of untrained algorithms.
|
563 |
+
self.opt_state_skeleton = self.opt.init(jnp.zeros(1))
|
564 |
+
|
565 |
+
def predict(self, rng_key: hk.PRNGSequence, features: _FeaturesChunked,
|
566 |
+
algorithm_index: Optional[int] = None):
|
567 |
+
"""Inference not implemented. Chunked model intended for training only."""
|
568 |
+
raise NotImplementedError
|
569 |
+
|
570 |
+
def _loss(self, params, rng_key, feedback, mp_state, algorithm_index):
|
571 |
+
(output_preds, hint_preds), mp_state = self.net_fn.apply(
|
572 |
+
params, rng_key, [feedback.features],
|
573 |
+
[mp_state],
|
574 |
+
repred=False,
|
575 |
+
init_mp_state=False,
|
576 |
+
algorithm_index=algorithm_index)
|
577 |
+
|
578 |
+
nb_nodes = _nb_nodes(feedback, is_chunked=True)
|
579 |
+
|
580 |
+
total_loss = 0.0
|
581 |
+
is_first = feedback.features.is_first
|
582 |
+
is_last = feedback.features.is_last
|
583 |
+
|
584 |
+
# Calculate output loss.
|
585 |
+
for truth in feedback.outputs:
|
586 |
+
total_loss += losses.output_loss_chunked(
|
587 |
+
truth=truth,
|
588 |
+
pred=output_preds[truth.name],
|
589 |
+
is_last=is_last,
|
590 |
+
nb_nodes=nb_nodes,
|
591 |
+
)
|
592 |
+
|
593 |
+
# Optionally accumulate hint losses.
|
594 |
+
if self.decode_hints:
|
595 |
+
for truth in feedback.features.hints:
|
596 |
+
loss = losses.hint_loss_chunked(
|
597 |
+
truth=truth,
|
598 |
+
pred=hint_preds[truth.name],
|
599 |
+
is_first=is_first,
|
600 |
+
nb_nodes=nb_nodes,
|
601 |
+
)
|
602 |
+
total_loss += loss
|
603 |
+
|
604 |
+
return total_loss, (mp_state,)
|
605 |
+
|
606 |
+
def _compute_grad(self, params, rng_key, feedback, mp_state, algorithm_index):
|
607 |
+
(lss, (mp_state,)), grads = jax.value_and_grad(self._loss, has_aux=True)(
|
608 |
+
params, rng_key, feedback, mp_state, algorithm_index)
|
609 |
+
return self._maybe_pmean(lss), mp_state, self._maybe_pmean(grads)
|
610 |
+
|
611 |
+
def _feedback(self, params, rng_key, feedback, mp_state, opt_state,
|
612 |
+
algorithm_index):
|
613 |
+
(lss, (mp_state,)), grads = jax.value_and_grad(self._loss, has_aux=True)(
|
614 |
+
params, rng_key, feedback, mp_state, algorithm_index)
|
615 |
+
grads = self._maybe_pmean(grads)
|
616 |
+
params, opt_state = self._update_params(params, grads, opt_state,
|
617 |
+
algorithm_index)
|
618 |
+
lss = self._maybe_pmean(lss)
|
619 |
+
return lss, params, opt_state, mp_state
|
620 |
+
|
621 |
+
def compute_grad(
|
622 |
+
self,
|
623 |
+
rng_key: hk.PRNGSequence,
|
624 |
+
feedback: _Feedback,
|
625 |
+
algorithm_index: Optional[Tuple[int, int]] = None,
|
626 |
+
) -> Tuple[float, _Array]:
|
627 |
+
"""Compute gradients."""
|
628 |
+
|
629 |
+
if algorithm_index is None:
|
630 |
+
assert len(self._spec) == 1
|
631 |
+
algorithm_index = (0, 0)
|
632 |
+
length_index, algorithm_index = algorithm_index
|
633 |
+
# Reusing init_mp_state improves performance.
|
634 |
+
# The next, commented out line, should be used for proper state keeping.
|
635 |
+
# mp_state = self.mp_states[length_index][algorithm_index]
|
636 |
+
mp_state = self.init_mp_states[length_index][algorithm_index]
|
637 |
+
rng_keys = _maybe_pmap_rng_key(rng_key) # pytype: disable=wrong-arg-types # numpy-scalars
|
638 |
+
feedback = _maybe_pmap_reshape(feedback, split_axis=1)
|
639 |
+
mp_state = _maybe_pmap_reshape(mp_state, split_axis=0)
|
640 |
+
|
641 |
+
loss, mp_state, grads = self.jitted_grad(
|
642 |
+
self._device_params, rng_keys, feedback, mp_state, algorithm_index)
|
643 |
+
loss = _maybe_pick_first_pmapped(loss)
|
644 |
+
grads = _maybe_pick_first_pmapped(grads)
|
645 |
+
mp_state = _maybe_restack_from_pmap(mp_state)
|
646 |
+
self.mp_states[length_index][algorithm_index] = mp_state
|
647 |
+
return loss, grads
|
648 |
+
|
649 |
+
def feedback(self, rng_key: hk.PRNGSequence, feedback: _Feedback,
|
650 |
+
algorithm_index=None) -> float:
|
651 |
+
if algorithm_index is None:
|
652 |
+
assert len(self._spec) == 1
|
653 |
+
algorithm_index = (0, 0)
|
654 |
+
length_index, algorithm_index = algorithm_index
|
655 |
+
# Reusing init_mp_state improves performance.
|
656 |
+
# The next, commented out line, should be used for proper state keeping.
|
657 |
+
# mp_state = self.mp_states[length_index][algorithm_index]
|
658 |
+
mp_state = self.init_mp_states[length_index][algorithm_index]
|
659 |
+
rng_keys = _maybe_pmap_rng_key(rng_key) # pytype: disable=wrong-arg-types # numpy-scalars
|
660 |
+
feedback = _maybe_pmap_reshape(feedback, split_axis=1)
|
661 |
+
mp_state = _maybe_pmap_reshape(mp_state, split_axis=0)
|
662 |
+
loss, self._device_params, self._device_opt_state, mp_state = (
|
663 |
+
self.jitted_feedback(
|
664 |
+
self._device_params, rng_keys, feedback,
|
665 |
+
mp_state, self._device_opt_state, algorithm_index))
|
666 |
+
loss = _maybe_pick_first_pmapped(loss)
|
667 |
+
mp_state = _maybe_restack_from_pmap(mp_state)
|
668 |
+
self.mp_states[length_index][algorithm_index] = mp_state
|
669 |
+
return loss
|
670 |
+
|
671 |
+
def verbose_loss(self, *args, **kwargs):
|
672 |
+
raise NotImplementedError
|
673 |
+
|
674 |
+
|
675 |
+
def _nb_nodes(feedback: _Feedback, is_chunked) -> int:
|
676 |
+
for inp in feedback.features.inputs:
|
677 |
+
if inp.location in [_Location.NODE, _Location.EDGE]:
|
678 |
+
if is_chunked:
|
679 |
+
return inp.data.shape[2] # inputs are time x batch x nodes x ...
|
680 |
+
else:
|
681 |
+
return inp.data.shape[1] # inputs are batch x nodes x ...
|
682 |
+
assert False
|
683 |
+
|
684 |
+
|
685 |
+
def _param_in_processor(module_name):
|
686 |
+
return processors.PROCESSOR_TAG in module_name
|
687 |
+
|
688 |
+
|
689 |
+
def _filter_out_processor(params: hk.Params) -> hk.Params:
|
690 |
+
return hk.data_structures.filter(
|
691 |
+
lambda module_name, n, v: not _param_in_processor(module_name), params)
|
692 |
+
|
693 |
+
|
694 |
+
def _filter_in_processor(params: hk.Params) -> hk.Params:
|
695 |
+
return hk.data_structures.filter(
|
696 |
+
lambda module_name, n, v: _param_in_processor(module_name), params)
|
697 |
+
|
698 |
+
|
699 |
+
def _is_not_done_broadcast(lengths, i, tensor):
|
700 |
+
is_not_done = (lengths > i + 1) * 1.0
|
701 |
+
while len(is_not_done.shape) < len(tensor.shape):
|
702 |
+
is_not_done = jnp.expand_dims(is_not_done, -1)
|
703 |
+
return is_not_done
|
704 |
+
|
705 |
+
|
706 |
+
def accum_opt_update(params, grads, opt_state, opt, freeze_processor):
|
707 |
+
"""Update params from gradients collected from several algorithms."""
|
708 |
+
# Average the gradients over all algos
|
709 |
+
grads = jax.tree_util.tree_map(
|
710 |
+
lambda *x: sum(x) / (sum([jnp.any(k) for k in x]) + 1e-12), *grads)
|
711 |
+
updates, opt_state = opt.update(grads, opt_state)
|
712 |
+
if freeze_processor:
|
713 |
+
params_subset = _filter_out_processor(params)
|
714 |
+
assert len(params) > len(params_subset)
|
715 |
+
assert params_subset
|
716 |
+
updates_subset = _filter_out_processor(updates)
|
717 |
+
new_params = optax.apply_updates(params_subset, updates_subset)
|
718 |
+
new_params = hk.data_structures.merge(params, new_params)
|
719 |
+
else:
|
720 |
+
new_params = optax.apply_updates(params, updates)
|
721 |
+
|
722 |
+
return new_params, opt_state
|
723 |
+
|
724 |
+
|
725 |
+
@functools.partial(jax.jit, static_argnames=['opt'])
|
726 |
+
def opt_update(opt, flat_grads, flat_opt_state):
|
727 |
+
return opt.update(flat_grads, flat_opt_state)
|
728 |
+
|
729 |
+
|
730 |
+
def filter_null_grads(grads, opt, opt_state, opt_state_skeleton, algo_idx):
|
731 |
+
"""Compute updates ignoring params that have no gradients.
|
732 |
+
|
733 |
+
This prevents untrained params (e.g., encoders/decoders for algorithms
|
734 |
+
that are not being trained) to accumulate, e.g., momentum from spurious
|
735 |
+
zero gradients.
|
736 |
+
|
737 |
+
Note: this works as intended for "per-parameter" optimizer state, such as
|
738 |
+
momentum. However, when the optimizer has some global state (such as the
|
739 |
+
step counts in Adam), the global state will be updated every time,
|
740 |
+
affecting also future updates of parameters that had null gradients in the
|
741 |
+
current step.
|
742 |
+
|
743 |
+
Args:
|
744 |
+
grads: Gradients for all parameters.
|
745 |
+
opt: Optax optimizer.
|
746 |
+
opt_state: Optimizer state.
|
747 |
+
opt_state_skeleton: A "skeleton" of optimizer state that has been
|
748 |
+
initialized with scalar parameters. This serves to traverse each parameter
|
749 |
+
of the otpimizer state during the opt state update.
|
750 |
+
algo_idx: Index of algorithm, to filter out unused encoders/decoders.
|
751 |
+
If None, no filtering happens.
|
752 |
+
Returns:
|
753 |
+
Updates and new optimizer state, where the parameters with null gradient
|
754 |
+
have not been taken into account.
|
755 |
+
"""
|
756 |
+
def _keep_in_algo(k, v):
|
757 |
+
"""Ignore params of encoders/decoders irrelevant for this algo."""
|
758 |
+
# Note: in shared pointer decoder modes, we should exclude shared params
|
759 |
+
# for algos that do not have pointer outputs.
|
760 |
+
if ((processors.PROCESSOR_TAG in k) or
|
761 |
+
(f'algo_{algo_idx}_' in k)):
|
762 |
+
return v
|
763 |
+
return jax.tree_util.tree_map(lambda x: None, v)
|
764 |
+
|
765 |
+
if algo_idx is None:
|
766 |
+
masked_grads = grads
|
767 |
+
else:
|
768 |
+
masked_grads = {k: _keep_in_algo(k, v) for k, v in grads.items()}
|
769 |
+
flat_grads, treedef = jax.tree_util.tree_flatten(masked_grads)
|
770 |
+
flat_opt_state = jax.tree_util.tree_map(
|
771 |
+
lambda _, x: x # pylint:disable=g-long-lambda
|
772 |
+
if isinstance(x, (np.ndarray, jax.Array))
|
773 |
+
else treedef.flatten_up_to(x),
|
774 |
+
opt_state_skeleton,
|
775 |
+
opt_state,
|
776 |
+
)
|
777 |
+
|
778 |
+
# Compute updates only for the params with gradient.
|
779 |
+
flat_updates, flat_opt_state = opt_update(opt, flat_grads, flat_opt_state)
|
780 |
+
|
781 |
+
def unflatten(flat, original):
|
782 |
+
"""Restore tree structure, filling missing (None) leaves with original."""
|
783 |
+
if isinstance(flat, (np.ndarray, jax.Array)):
|
784 |
+
return flat
|
785 |
+
return jax.tree_util.tree_map(lambda x, y: x if y is None else y, original,
|
786 |
+
treedef.unflatten(flat))
|
787 |
+
|
788 |
+
# Restore the state and updates tree structure.
|
789 |
+
new_opt_state = jax.tree_util.tree_map(lambda _, x, y: unflatten(x, y),
|
790 |
+
opt_state_skeleton, flat_opt_state,
|
791 |
+
opt_state)
|
792 |
+
updates = unflatten(flat_updates,
|
793 |
+
jax.tree_util.tree_map(lambda x: 0., grads))
|
794 |
+
return updates, new_opt_state
|
benchmarks/CLRS/env/baselines_test.py
ADDED
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Unit tests for `baselines.py`."""
|
17 |
+
|
18 |
+
import copy
|
19 |
+
import functools
|
20 |
+
from typing import Generator
|
21 |
+
|
22 |
+
from absl.testing import absltest
|
23 |
+
from absl.testing import parameterized
|
24 |
+
import chex
|
25 |
+
|
26 |
+
from clrs._src import baselines
|
27 |
+
from clrs._src import dataset
|
28 |
+
from clrs._src import probing
|
29 |
+
from clrs._src import processors
|
30 |
+
from clrs._src import samplers
|
31 |
+
from clrs._src import specs
|
32 |
+
|
33 |
+
import haiku as hk
|
34 |
+
import jax
|
35 |
+
import numpy as np
|
36 |
+
|
37 |
+
_Array = np.ndarray
|
38 |
+
|
39 |
+
|
40 |
+
def _error(x, y):
|
41 |
+
return np.sum(np.abs(x-y))
|
42 |
+
|
43 |
+
|
44 |
+
def _make_sampler(algo: str, length: int) -> samplers.Sampler:
|
45 |
+
sampler, _ = samplers.build_sampler(
|
46 |
+
algo,
|
47 |
+
seed=samplers.CLRS30['val']['seed'],
|
48 |
+
num_samples=samplers.CLRS30['val']['num_samples'],
|
49 |
+
length=length,
|
50 |
+
)
|
51 |
+
return sampler
|
52 |
+
|
53 |
+
|
54 |
+
def _without_permutation(feedback):
|
55 |
+
"""Replace should-be permutations with pointers."""
|
56 |
+
outputs = []
|
57 |
+
for x in feedback.outputs:
|
58 |
+
if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
|
59 |
+
outputs.append(x)
|
60 |
+
continue
|
61 |
+
assert x.location == specs.Location.NODE
|
62 |
+
outputs.append(probing.DataPoint(name=x.name, location=x.location,
|
63 |
+
type_=specs.Type.POINTER, data=x.data))
|
64 |
+
return feedback._replace(outputs=outputs)
|
65 |
+
|
66 |
+
|
67 |
+
def _make_iterable_sampler(
|
68 |
+
algo: str, batch_size: int,
|
69 |
+
length: int) -> Generator[samplers.Feedback, None, None]:
|
70 |
+
sampler = _make_sampler(algo, length)
|
71 |
+
while True:
|
72 |
+
yield _without_permutation(sampler.next(batch_size))
|
73 |
+
|
74 |
+
|
75 |
+
def _remove_permutation_from_spec(spec):
|
76 |
+
"""Modify spec to turn permutation type to pointer."""
|
77 |
+
new_spec = {}
|
78 |
+
for k in spec:
|
79 |
+
if (spec[k][1] == specs.Location.NODE and
|
80 |
+
spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
|
81 |
+
new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
|
82 |
+
else:
|
83 |
+
new_spec[k] = spec[k]
|
84 |
+
return new_spec
|
85 |
+
|
86 |
+
|
87 |
+
class BaselinesTest(parameterized.TestCase):
|
88 |
+
|
89 |
+
def test_full_vs_chunked(self):
|
90 |
+
"""Test that chunking does not affect gradients."""
|
91 |
+
|
92 |
+
batch_size = 4
|
93 |
+
length = 8
|
94 |
+
algo = 'insertion_sort'
|
95 |
+
spec = _remove_permutation_from_spec(specs.SPECS[algo])
|
96 |
+
rng_key = jax.random.PRNGKey(42)
|
97 |
+
|
98 |
+
full_ds = _make_iterable_sampler(algo, batch_size, length)
|
99 |
+
chunked_ds = dataset.chunkify(
|
100 |
+
_make_iterable_sampler(algo, batch_size, length),
|
101 |
+
length)
|
102 |
+
double_chunked_ds = dataset.chunkify(
|
103 |
+
_make_iterable_sampler(algo, batch_size, length),
|
104 |
+
length * 2)
|
105 |
+
|
106 |
+
full_batches = [next(full_ds) for _ in range(2)]
|
107 |
+
chunked_batches = [next(chunked_ds) for _ in range(2)]
|
108 |
+
double_chunk_batch = next(double_chunked_ds)
|
109 |
+
|
110 |
+
with chex.fake_jit(): # jitting makes test longer
|
111 |
+
|
112 |
+
processor_factory = processors.get_processor_factory(
|
113 |
+
'mpnn', use_ln=False, nb_triplet_fts=0)
|
114 |
+
common_args = dict(processor_factory=processor_factory, hidden_dim=8,
|
115 |
+
learning_rate=0.01,
|
116 |
+
decode_hints=True, encode_hints=True)
|
117 |
+
|
118 |
+
b_full = baselines.BaselineModel(
|
119 |
+
spec, dummy_trajectory=full_batches[0], **common_args)
|
120 |
+
b_full.init(full_batches[0].features, seed=42) # pytype: disable=wrong-arg-types # jax-ndarray
|
121 |
+
full_params = b_full.params
|
122 |
+
full_loss_0 = b_full.feedback(rng_key, full_batches[0])
|
123 |
+
b_full.params = full_params
|
124 |
+
full_loss_1 = b_full.feedback(rng_key, full_batches[1])
|
125 |
+
new_full_params = b_full.params
|
126 |
+
|
127 |
+
b_chunked = baselines.BaselineModelChunked(
|
128 |
+
spec, dummy_trajectory=chunked_batches[0], **common_args)
|
129 |
+
b_chunked.init([[chunked_batches[0].features]], seed=42) # pytype: disable=wrong-arg-types # jax-ndarray
|
130 |
+
chunked_params = b_chunked.params
|
131 |
+
jax.tree_util.tree_map(np.testing.assert_array_equal, full_params,
|
132 |
+
chunked_params)
|
133 |
+
chunked_loss_0 = b_chunked.feedback(rng_key, chunked_batches[0])
|
134 |
+
b_chunked.params = chunked_params
|
135 |
+
chunked_loss_1 = b_chunked.feedback(rng_key, chunked_batches[1])
|
136 |
+
new_chunked_params = b_chunked.params
|
137 |
+
|
138 |
+
b_chunked.params = chunked_params
|
139 |
+
double_chunked_loss = b_chunked.feedback(rng_key, double_chunk_batch)
|
140 |
+
|
141 |
+
# Test that losses match
|
142 |
+
np.testing.assert_allclose(full_loss_0, chunked_loss_0, rtol=1e-4)
|
143 |
+
np.testing.assert_allclose(full_loss_1, chunked_loss_1, rtol=1e-4)
|
144 |
+
np.testing.assert_allclose(full_loss_0 + full_loss_1,
|
145 |
+
2 * double_chunked_loss,
|
146 |
+
rtol=1e-4)
|
147 |
+
|
148 |
+
# Test that gradients are the same (parameters changed equally).
|
149 |
+
# First check that gradients were not zero, i.e., parameters have changed.
|
150 |
+
param_change, _ = jax.tree_util.tree_flatten(
|
151 |
+
jax.tree_util.tree_map(_error, full_params, new_full_params))
|
152 |
+
self.assertGreater(np.mean(param_change), 0.1)
|
153 |
+
# Now check that full and chunked gradients are the same.
|
154 |
+
jax.tree_util.tree_map(
|
155 |
+
functools.partial(np.testing.assert_allclose, rtol=1e-4),
|
156 |
+
new_full_params, new_chunked_params)
|
157 |
+
|
158 |
+
def test_multi_vs_single(self):
|
159 |
+
"""Test that multi = single when we only train one of the algorithms."""
|
160 |
+
|
161 |
+
batch_size = 4
|
162 |
+
length = 16
|
163 |
+
algos = ['insertion_sort', 'activity_selector', 'bfs']
|
164 |
+
spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
|
165 |
+
rng_key = jax.random.PRNGKey(42)
|
166 |
+
|
167 |
+
full_ds = [_make_iterable_sampler(algo, batch_size, length)
|
168 |
+
for algo in algos]
|
169 |
+
full_batches = [next(ds) for ds in full_ds]
|
170 |
+
full_batches_2 = [next(ds) for ds in full_ds]
|
171 |
+
|
172 |
+
with chex.fake_jit(): # jitting makes test longer
|
173 |
+
|
174 |
+
processor_factory = processors.get_processor_factory(
|
175 |
+
'mpnn', use_ln=False, nb_triplet_fts=0)
|
176 |
+
common_args = dict(processor_factory=processor_factory, hidden_dim=8,
|
177 |
+
learning_rate=0.01,
|
178 |
+
decode_hints=True, encode_hints=True)
|
179 |
+
|
180 |
+
b_single = baselines.BaselineModel(
|
181 |
+
spec[0], dummy_trajectory=full_batches[0], **common_args)
|
182 |
+
b_multi = baselines.BaselineModel(
|
183 |
+
spec, dummy_trajectory=full_batches, **common_args)
|
184 |
+
b_single.init(full_batches[0].features, seed=0) # pytype: disable=wrong-arg-types # jax-ndarray
|
185 |
+
b_multi.init([f.features for f in full_batches], seed=0) # pytype: disable=wrong-arg-types # jax-ndarray
|
186 |
+
|
187 |
+
single_params = []
|
188 |
+
single_losses = []
|
189 |
+
multi_params = []
|
190 |
+
multi_losses = []
|
191 |
+
|
192 |
+
single_params.append(copy.deepcopy(b_single.params))
|
193 |
+
single_losses.append(b_single.feedback(rng_key, full_batches[0]))
|
194 |
+
single_params.append(copy.deepcopy(b_single.params))
|
195 |
+
single_losses.append(b_single.feedback(rng_key, full_batches_2[0]))
|
196 |
+
single_params.append(copy.deepcopy(b_single.params))
|
197 |
+
|
198 |
+
multi_params.append(copy.deepcopy(b_multi.params))
|
199 |
+
multi_losses.append(b_multi.feedback(rng_key, full_batches[0],
|
200 |
+
algorithm_index=0))
|
201 |
+
multi_params.append(copy.deepcopy(b_multi.params))
|
202 |
+
multi_losses.append(b_multi.feedback(rng_key, full_batches_2[0],
|
203 |
+
algorithm_index=0))
|
204 |
+
multi_params.append(copy.deepcopy(b_multi.params))
|
205 |
+
|
206 |
+
# Test that losses match
|
207 |
+
np.testing.assert_array_equal(single_losses, multi_losses)
|
208 |
+
# Test that loss decreased
|
209 |
+
assert single_losses[1] < single_losses[0]
|
210 |
+
|
211 |
+
# Test that param changes were the same in single and multi-algorithm
|
212 |
+
for single, multi in zip(single_params, multi_params):
|
213 |
+
assert hk.data_structures.is_subset(subset=single, superset=multi)
|
214 |
+
for module_name, params in single.items():
|
215 |
+
jax.tree_util.tree_map(np.testing.assert_array_equal, params,
|
216 |
+
multi[module_name])
|
217 |
+
|
218 |
+
# Test that params change for the trained algorithm, but not the others
|
219 |
+
for module_name, params in multi_params[0].items():
|
220 |
+
param_changes = jax.tree_util.tree_map(lambda a, b: np.sum(np.abs(a - b)),
|
221 |
+
params,
|
222 |
+
multi_params[1][module_name])
|
223 |
+
param_change = sum(param_changes.values())
|
224 |
+
if module_name in single_params[0]: # params of trained algorithm
|
225 |
+
assert param_change > 1e-3
|
226 |
+
else: # params of non-trained algorithms
|
227 |
+
assert param_change == 0.0
|
228 |
+
|
229 |
+
@parameterized.parameters(True, False)
|
230 |
+
def test_multi_algorithm_idx(self, is_chunked):
|
231 |
+
"""Test that algorithm selection works as intended."""
|
232 |
+
|
233 |
+
batch_size = 4
|
234 |
+
length = 8
|
235 |
+
algos = ['insertion_sort', 'activity_selector', 'bfs']
|
236 |
+
spec = [_remove_permutation_from_spec(specs.SPECS[algo]) for algo in algos]
|
237 |
+
rng_key = jax.random.PRNGKey(42)
|
238 |
+
|
239 |
+
if is_chunked:
|
240 |
+
ds = [dataset.chunkify(_make_iterable_sampler(algo, batch_size, length),
|
241 |
+
2 * length) for algo in algos]
|
242 |
+
else:
|
243 |
+
ds = [_make_iterable_sampler(algo, batch_size, length) for algo in algos]
|
244 |
+
batches = [next(d) for d in ds]
|
245 |
+
|
246 |
+
processor_factory = processors.get_processor_factory(
|
247 |
+
'mpnn', use_ln=False, nb_triplet_fts=0)
|
248 |
+
common_args = dict(processor_factory=processor_factory, hidden_dim=8,
|
249 |
+
learning_rate=0.01,
|
250 |
+
decode_hints=True, encode_hints=True)
|
251 |
+
if is_chunked:
|
252 |
+
baseline = baselines.BaselineModelChunked(
|
253 |
+
spec, dummy_trajectory=batches, **common_args)
|
254 |
+
baseline.init([[f.features for f in batches]], seed=0) # pytype: disable=wrong-arg-types # jax-ndarray
|
255 |
+
else:
|
256 |
+
baseline = baselines.BaselineModel(
|
257 |
+
spec, dummy_trajectory=batches, **common_args)
|
258 |
+
baseline.init([f.features for f in batches], seed=0) # pytype: disable=wrong-arg-types # jax-ndarray
|
259 |
+
|
260 |
+
# Find out what parameters change when we train each algorithm
|
261 |
+
def _change(x, y):
|
262 |
+
changes = {}
|
263 |
+
for module_name, params in x.items():
|
264 |
+
changes[module_name] = sum(
|
265 |
+
jax.tree_util.tree_map(
|
266 |
+
lambda a, b: np.sum(np.abs(a-b)), params, y[module_name]
|
267 |
+
).values())
|
268 |
+
return changes
|
269 |
+
|
270 |
+
param_changes = []
|
271 |
+
for algo_idx in range(len(algos)):
|
272 |
+
init_params = copy.deepcopy(baseline.params)
|
273 |
+
_ = baseline.feedback(
|
274 |
+
rng_key,
|
275 |
+
batches[algo_idx],
|
276 |
+
algorithm_index=(0, algo_idx) if is_chunked else algo_idx)
|
277 |
+
param_changes.append(_change(init_params, baseline.params))
|
278 |
+
|
279 |
+
# Test that non-changing parameters correspond to encoders/decoders
|
280 |
+
# associated with the non-trained algorithms
|
281 |
+
unchanged = [[k for k in pc if pc[k] == 0] for pc in param_changes]
|
282 |
+
|
283 |
+
def _get_other_algos(algo_idx, modules):
|
284 |
+
return set([k for k in modules if '_construct_encoders_decoders' in k
|
285 |
+
and f'algo_{algo_idx}' not in k])
|
286 |
+
|
287 |
+
for algo_idx in range(len(algos)):
|
288 |
+
expected_unchanged = _get_other_algos(algo_idx, baseline.params.keys())
|
289 |
+
self.assertNotEmpty(expected_unchanged)
|
290 |
+
self.assertSetEqual(expected_unchanged, set(unchanged[algo_idx]))
|
291 |
+
|
292 |
+
|
293 |
+
if __name__ == '__main__':
|
294 |
+
absltest.main()
|
benchmarks/CLRS/env/data_description.txt
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
The CLRS Algorithmic Reasoning Benchmark
|
2 |
+
|
3 |
+
Learning representations of algorithms is an emerging area of machine learning, seeking to bridge concepts from neural networks with classical algorithms. The CLRS Algorithmic Reasoning Benchmark (CLRS) consolidates and extends previous work toward evaluation algorithmic reasoning by providing a suite of implementations of classical algorithms. These algorithms have been selected from the third edition of the standard Introduction to Algorithms by Cormen, Leiserson, Rivest and Stein.
|
4 |
+
|
5 |
+
Algorithms as graphs
|
6 |
+
CLRS implements the selected algorithms in an idiomatic way, which aligns as closely as possible to the original CLRS 3ed pseudocode. By controlling the input data distribution to conform to the preconditions we are able to automatically generate input/output pairs. We additionally provide trajectories of "hints" that expose the internal state of each algorithm, to both optionally simplify the learning challenge and to distinguish between different algorithms that solve the same overall task (e.g. sorting).
|
7 |
+
|
8 |
+
In the most generic sense, algorithms can be seen as manipulating sets of objects, along with any relations between them (which can themselves be decomposed into binary relations). Accordingly, we study all of the algorithms in this benchmark using a graph representation. In the event that objects obey a more strict ordered structure (e.g. arrays or rooted trees), we impose this ordering through inclusion of predecessor links.
|
9 |
+
|
10 |
+
How it works
|
11 |
+
For each algorithm, we provide a canonical set of train, eval and test trajectories for benchmarking out-of-distribution generalization.
|
12 |
+
|
13 |
+
Trajectories Problem Size
|
14 |
+
Train 1000 16
|
15 |
+
Eval 32 x multiplier 16
|
16 |
+
Test 32 x multiplier 64
|
17 |
+
Here, "problem size" refers to e.g. the length of an array or number of nodes in a graph, depending on the algorithm. "multiplier" is an algorithm-specific factor that increases the number of available eval and test trajectories to compensate for paucity of evaluation signals. "multiplier" is 1 for all algorithms except:
|
18 |
+
|
19 |
+
Maximum subarray (Kadane), for which "multiplier" is 32.
|
20 |
+
Quick select, minimum, binary search, string matchers (both naive and KMP), and segment intersection, for which "multiplier" is 64.
|
21 |
+
The trajectories can be used like so:
|
22 |
+
|
23 |
+
train_ds, num_samples, spec = clrs.create_dataset(
|
24 |
+
folder='/tmp/CLRS30', algorithm='bfs',
|
25 |
+
split='train', batch_size=32)
|
26 |
+
|
27 |
+
for i, feedback in enumerate(train_ds.as_numpy_iterator()):
|
28 |
+
if i == 0:
|
29 |
+
model.init(feedback.features, initial_seed)
|
30 |
+
loss = model.feedback(rng_key, feedback)
|
31 |
+
Here, feedback is a namedtuple with the following structure:
|
32 |
+
|
33 |
+
Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])
|
34 |
+
Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
|
35 |
+
where the content of Features can be used for training and outputs is reserved for evaluation. Each field of the tuple is an ndarray with a leading batch dimension. Because hints are provided for the full algorithm trajectory, these contain an additional time dimension padded up to the maximum length max(T) of any trajectory within the dataset. The lengths field specifies the true length t <= max(T) for each trajectory, which can be used e.g. for loss masking.
|
benchmarks/CLRS/env/dataset.py
ADDED
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""CLRS dataset."""
|
16 |
+
|
17 |
+
import dataclasses
|
18 |
+
|
19 |
+
import functools
|
20 |
+
from typing import Iterator
|
21 |
+
|
22 |
+
from clrs._src import probing
|
23 |
+
from clrs._src import samplers
|
24 |
+
from clrs._src import specs
|
25 |
+
|
26 |
+
import jax
|
27 |
+
import numpy as np
|
28 |
+
import tensorflow as tf
|
29 |
+
import tensorflow_datasets as tfds
|
30 |
+
|
31 |
+
|
32 |
+
def _correct_axis_filtering(tensor, index, name):
|
33 |
+
if 'hint_' in name:
|
34 |
+
return tensor[:, index]
|
35 |
+
else:
|
36 |
+
return tensor[index]
|
37 |
+
|
38 |
+
|
39 |
+
@dataclasses.dataclass
|
40 |
+
class CLRSConfig(tfds.core.BuilderConfig):
|
41 |
+
"""Specify the split in the variant because they have different shapes."""
|
42 |
+
split: str = ''
|
43 |
+
|
44 |
+
|
45 |
+
DEFAULT_BUILDER_CONFIGS = []
|
46 |
+
|
47 |
+
|
48 |
+
def _build_default_builder_configs():
|
49 |
+
for split in ['train', 'val', 'test']:
|
50 |
+
for alg in specs.CLRS_30_ALGS:
|
51 |
+
DEFAULT_BUILDER_CONFIGS.append(
|
52 |
+
CLRSConfig(name=f'{alg}_{split}', split=split))
|
53 |
+
|
54 |
+
|
55 |
+
_build_default_builder_configs()
|
56 |
+
|
57 |
+
|
58 |
+
class CLRSDataset(tfds.core.GeneratorBasedBuilder):
|
59 |
+
"""DatasetBuilder for my_dataset dataset."""
|
60 |
+
|
61 |
+
VERSION = tfds.core.Version('1.0.0')
|
62 |
+
RELEASE_NOTES = {
|
63 |
+
'1.0.0': 'Initial release.',
|
64 |
+
}
|
65 |
+
BUILDER_CONFIGS = DEFAULT_BUILDER_CONFIGS
|
66 |
+
|
67 |
+
_instantiated_dataset = None
|
68 |
+
_instantiated_dataset_name = ''
|
69 |
+
_instantiated_dataset_split = ''
|
70 |
+
|
71 |
+
def _num_samples(self, algorithm_name):
|
72 |
+
num_samples = samplers.CLRS30[self._builder_config.split]['num_samples'] # pytype: disable=attribute-error # always-use-return-annotations
|
73 |
+
if self._builder_config.split != 'train': # pytype: disable=attribute-error # always-use-return-annotations
|
74 |
+
# Generate more samples for those algorithms in which the number of
|
75 |
+
# signals is small.
|
76 |
+
num_samples *= specs.CLRS_30_ALGS_SETTINGS[algorithm_name][
|
77 |
+
'num_samples_multiplier']
|
78 |
+
return num_samples
|
79 |
+
|
80 |
+
def _create_data(self, single_sample):
|
81 |
+
algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
|
82 |
+
num_samples = self._num_samples(algorithm_name)
|
83 |
+
sampler, _ = samplers.build_sampler(
|
84 |
+
algorithm_name,
|
85 |
+
seed=samplers.CLRS30[self._builder_config.split]['seed'], # pytype: disable=attribute-error # always-use-return-annotations
|
86 |
+
num_samples=num_samples,
|
87 |
+
length=samplers.CLRS30[self._builder_config.split]['length'], # pytype: disable=attribute-error # always-use-return-annotations
|
88 |
+
)
|
89 |
+
sampled_dataset = sampler.next(batch_size=1 if single_sample else None)
|
90 |
+
data = {'input_' + t.name: t.data for t in sampled_dataset.features.inputs}
|
91 |
+
# All other data points have input_, hint_, and output_ prefixes, so we
|
92 |
+
# guarantee that this key is unused.
|
93 |
+
data['lengths'] = sampled_dataset.features.lengths
|
94 |
+
data.update({'output_' + t.name: t.data for t in sampled_dataset.outputs})
|
95 |
+
data.update({
|
96 |
+
'hint_' + t.name: t.data for t in sampled_dataset.features.hints})
|
97 |
+
self._instantiated_dataset = data
|
98 |
+
|
99 |
+
def _info(self) -> tfds.core.DatasetInfo:
|
100 |
+
if tf.io.gfile.exists(self.data_dir):
|
101 |
+
info = tfds.core.DatasetInfo(builder=self)
|
102 |
+
info.read_from_directory(self.data_dir)
|
103 |
+
return info
|
104 |
+
|
105 |
+
if (self._instantiated_dataset_name != self._builder_config.name
|
106 |
+
or self._instantiated_dataset_split != self._builder_config.split): # pytype: disable=attribute-error # always-use-return-annotations
|
107 |
+
self._create_data(single_sample=True)
|
108 |
+
|
109 |
+
data = {k: _correct_axis_filtering(v, 0, k)
|
110 |
+
for k, v in self._instantiated_dataset.items()}
|
111 |
+
data_info = {
|
112 |
+
k: tfds.features.Tensor(shape=v.shape, dtype=tf.dtypes.as_dtype(
|
113 |
+
v.dtype)) for k, v in data.items()}
|
114 |
+
return tfds.core.DatasetInfo(
|
115 |
+
builder=self,
|
116 |
+
features=tfds.features.FeaturesDict(data_info),
|
117 |
+
)
|
118 |
+
|
119 |
+
def _split_generators(self, dl_manager: tfds.download.DownloadManager):
|
120 |
+
"""Download the data and define splits."""
|
121 |
+
if (self._instantiated_dataset_name != self._builder_config.name
|
122 |
+
or self._instantiated_dataset_split != self._builder_config.split): # pytype: disable=attribute-error # always-use-return-annotations
|
123 |
+
self._create_data(single_sample=False)
|
124 |
+
self._instantiated_dataset_name = self._builder_config.name
|
125 |
+
self._instantiated_dataset_split = self._builder_config.split # pytype: disable=attribute-error # always-use-return-annotations
|
126 |
+
return {self._builder_config.split: self._generate_examples()} # pytype: disable=attribute-error # always-use-return-annotations
|
127 |
+
|
128 |
+
def _generate_examples(self):
|
129 |
+
"""Generator of examples for each split."""
|
130 |
+
algorithm_name = '_'.join(self._builder_config.name.split('_')[:-1])
|
131 |
+
for i in range(self._num_samples(algorithm_name)):
|
132 |
+
data = {k: _correct_axis_filtering(v, i, k)
|
133 |
+
for k, v in self._instantiated_dataset.items()}
|
134 |
+
yield str(i), data
|
135 |
+
|
136 |
+
|
137 |
+
def _get_clrs_file_name():
|
138 |
+
return f'CLRS30_v{CLRSDataset.VERSION}.tar.gz'
|
139 |
+
|
140 |
+
|
141 |
+
def get_dataset_gcp_url():
|
142 |
+
return f'https://storage.googleapis.com/dm-clrs/{_get_clrs_file_name()}'
|
143 |
+
|
144 |
+
|
145 |
+
def get_clrs_folder():
|
146 |
+
return f'CLRS30_v{CLRSDataset.VERSION}'
|
147 |
+
|
148 |
+
|
149 |
+
def _preprocess(data_point, algorithm=None):
|
150 |
+
"""Convert sampled inputs into DataPoints."""
|
151 |
+
inputs = []
|
152 |
+
outputs = []
|
153 |
+
hints = []
|
154 |
+
lengths = None
|
155 |
+
|
156 |
+
for name, data in data_point.items():
|
157 |
+
if name == 'lengths':
|
158 |
+
lengths = data
|
159 |
+
continue
|
160 |
+
data_point_name = name.split('_')
|
161 |
+
name = '_'.join(data_point_name[1:])
|
162 |
+
(stage, location, dp_type) = specs.SPECS[algorithm][name]
|
163 |
+
assert stage == data_point_name[0]
|
164 |
+
if stage == specs.Stage.HINT:
|
165 |
+
data = tf.experimental.numpy.swapaxes(data, 0, 1)
|
166 |
+
dp = probing.DataPoint(name, location, dp_type, data)
|
167 |
+
if stage == specs.Stage.INPUT:
|
168 |
+
inputs.append(dp)
|
169 |
+
elif stage == specs.Stage.OUTPUT:
|
170 |
+
outputs.append(dp)
|
171 |
+
else:
|
172 |
+
hints.append(dp)
|
173 |
+
return samplers.Feedback(
|
174 |
+
samplers.Features(tuple(inputs), tuple(hints), lengths), tuple(outputs))
|
175 |
+
|
176 |
+
|
177 |
+
def create_dataset(folder, algorithm, split, batch_size):
|
178 |
+
dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
|
179 |
+
data_dir=folder, split=split)
|
180 |
+
num_samples = len(dataset) # Must be done here for correct size
|
181 |
+
dataset = dataset.repeat()
|
182 |
+
dataset = dataset.batch(batch_size)
|
183 |
+
return (dataset.map(lambda d: _preprocess(d, algorithm=algorithm)),
|
184 |
+
num_samples,
|
185 |
+
specs.SPECS[algorithm])
|
186 |
+
|
187 |
+
|
188 |
+
def _copy_hint(source, dest, i, start_source, start_dest, to_add):
|
189 |
+
"""Copy from full-sample hint to a hint chunk."""
|
190 |
+
assert np.all(dest[start_dest:, i:] == 0)
|
191 |
+
assert start_dest < dest.shape[0]
|
192 |
+
assert start_dest + to_add <= dest.shape[0]
|
193 |
+
assert start_source < source.shape[0]
|
194 |
+
assert start_source + to_add <= source.shape[0]
|
195 |
+
dest[start_dest:start_dest+to_add, i] = source[
|
196 |
+
start_source:start_source+to_add, i]
|
197 |
+
return dest
|
198 |
+
|
199 |
+
|
200 |
+
def _copy_io(source, dest, i, start_dest, to_add):
|
201 |
+
"""Copy from an input or output to an input or output chunk."""
|
202 |
+
assert np.all(dest[start_dest:, i:] == 0)
|
203 |
+
dest[start_dest:start_dest+to_add, i] = source[i]
|
204 |
+
return dest
|
205 |
+
|
206 |
+
|
207 |
+
def chunkify(dataset: Iterator[samplers.Feedback], chunk_length: int):
|
208 |
+
"""Generator of fixed-length chunks from full-trajectory samples.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
dataset: full-sample dataset as numpy iterator.
|
212 |
+
chunk_length: time length of chunks.
|
213 |
+
Yields:
|
214 |
+
Fixed-timelength chunks of data. Each tensor of inputs, hints and outputs
|
215 |
+
has dimensions chunk_length x batch_size x ... Samples are not time-padded,
|
216 |
+
after the end of one sample immediately comes the next. Since different
|
217 |
+
samples can have different time lengths, the beginnings and ends of samples
|
218 |
+
within a batch do not need to coincide. For this reason, the chunked
|
219 |
+
dataset features include two chunk_length x batch_size int tensors,
|
220 |
+
`is_first` and `is_last`, that mark the beginning and end of each sample.
|
221 |
+
For example, if `chunk_legnth`==6 and `batch_size`==2 and the first
|
222 |
+
full-sample batch had one sample of length 3 and one of length 5,
|
223 |
+
we would have a first chunked batch with the following `is_first` and
|
224 |
+
`is_last` tensors:
|
225 |
+
|
226 |
+
is_first = [[1, 1] is_last = [[0, 0] ( sample id [[0 1]
|
227 |
+
[0, 0] [0, 0] [0 1]
|
228 |
+
[0, 0] [1, 0] [0 1]
|
229 |
+
[1, 0] [0, 0] [2 1]
|
230 |
+
[0, 0] [0, 1] [2 1]
|
231 |
+
[0, 1]] [0, 0]] [2 3]] )
|
232 |
+
|
233 |
+
while the data in the inputs, outputs and hints tensors would correspond
|
234 |
+
to samples as identified by the sample_id indicated above for reference.
|
235 |
+
Notice that, while in the full-sample dataset inputs and outputs have
|
236 |
+
no time dimension, here they do; the input and output tensors are simply
|
237 |
+
repeated along each sample's time length.
|
238 |
+
"""
|
239 |
+
def _get_batch():
|
240 |
+
d = next(dataset)
|
241 |
+
return (d.features.inputs, d.features.hints, d.outputs,
|
242 |
+
d.features.lengths.astype(int))
|
243 |
+
|
244 |
+
inputs, hints, outputs, lengths = _get_batch()
|
245 |
+
for inp in inputs:
|
246 |
+
if inp.location in [specs.Location.NODE, specs.Location.EDGE]:
|
247 |
+
batch_size = inp.data.shape[0]
|
248 |
+
break
|
249 |
+
|
250 |
+
io_chunk = lambda x: np.zeros((chunk_length,) + x.shape, dtype=x.dtype)
|
251 |
+
chunk_inputs = jax.tree_util.tree_map(io_chunk, inputs)
|
252 |
+
chunk_outputs = jax.tree_util.tree_map(io_chunk, outputs)
|
253 |
+
|
254 |
+
hint_chunk = lambda x: np.zeros((chunk_length,) + x.shape[1:], dtype=x.dtype)
|
255 |
+
chunk_hints = jax.tree_util.tree_map(hint_chunk, hints)
|
256 |
+
|
257 |
+
inputs = [inputs]
|
258 |
+
hints = [hints]
|
259 |
+
outputs = [outputs]
|
260 |
+
left = [lengths.copy()]
|
261 |
+
lengths = [lengths.copy()]
|
262 |
+
|
263 |
+
while True:
|
264 |
+
# Create a new empty chunk
|
265 |
+
chunk_inputs = jax.tree_util.tree_map(np.zeros_like, chunk_inputs)
|
266 |
+
chunk_hints = jax.tree_util.tree_map(np.zeros_like, chunk_hints)
|
267 |
+
chunk_outputs = jax.tree_util.tree_map(np.zeros_like, chunk_outputs)
|
268 |
+
start_mark = np.zeros((chunk_length, batch_size), dtype=int)
|
269 |
+
end_mark = np.zeros((chunk_length, batch_size), dtype=int)
|
270 |
+
|
271 |
+
# Get enough data batches to fill the new chunk
|
272 |
+
while np.any(np.sum(left, axis=0) < chunk_length):
|
273 |
+
inp, hh, out, ll = _get_batch()
|
274 |
+
inputs.append(inp)
|
275 |
+
hints.append(hh)
|
276 |
+
outputs.append(out)
|
277 |
+
left.append(ll.copy())
|
278 |
+
lengths.append(ll.copy())
|
279 |
+
|
280 |
+
# Fill the chunk, one batch element at a time
|
281 |
+
for i in range(batch_size):
|
282 |
+
total, idx = 0, 0
|
283 |
+
while total < chunk_length:
|
284 |
+
to_add = min(left[idx][i], chunk_length - total)
|
285 |
+
if to_add:
|
286 |
+
start = lengths[idx][i] - left[idx][i]
|
287 |
+
assert start >= 0
|
288 |
+
f_io = functools.partial(_copy_io, i=i, start_dest=total,
|
289 |
+
to_add=to_add)
|
290 |
+
chunk_inputs = jax.tree_util.tree_map(f_io, inputs[idx], chunk_inputs)
|
291 |
+
chunk_outputs = jax.tree_util.tree_map(f_io, outputs[idx],
|
292 |
+
chunk_outputs)
|
293 |
+
f_hint = functools.partial(_copy_hint, i=i, start_source=start,
|
294 |
+
start_dest=total, to_add=to_add)
|
295 |
+
chunk_hints = jax.tree_util.tree_map(f_hint, hints[idx], chunk_hints)
|
296 |
+
if start == 0:
|
297 |
+
start_mark[total, i] = 1
|
298 |
+
total += to_add
|
299 |
+
left[idx][i] -= to_add
|
300 |
+
assert left[idx][i] >= 0
|
301 |
+
if left[idx][i] == 0:
|
302 |
+
end_mark[total - 1, i] = 1
|
303 |
+
idx += 1
|
304 |
+
assert total == chunk_length
|
305 |
+
|
306 |
+
while left and np.all(left[0] == 0):
|
307 |
+
inputs.pop(0)
|
308 |
+
hints.pop(0)
|
309 |
+
outputs.pop(0)
|
310 |
+
left.pop(0)
|
311 |
+
lengths.pop(0)
|
312 |
+
|
313 |
+
yield samplers.Feedback(
|
314 |
+
samplers.FeaturesChunked(chunk_inputs, chunk_hints,
|
315 |
+
start_mark, end_mark),
|
316 |
+
chunk_outputs)
|
317 |
+
|
318 |
+
|
319 |
+
def create_chunked_dataset(folder, algorithm, split, batch_size, chunk_length):
|
320 |
+
dataset = tfds.load(f'clrs_dataset/{algorithm}_{split}',
|
321 |
+
data_dir=folder, split=split)
|
322 |
+
dataset = dataset.repeat()
|
323 |
+
dataset = dataset.batch(batch_size)
|
324 |
+
dataset = dataset.map(lambda d: _preprocess(d, algorithm=algorithm))
|
325 |
+
dataset = dataset.as_numpy_iterator()
|
326 |
+
return chunkify(dataset, chunk_length), specs.SPECS[algorithm]
|
benchmarks/CLRS/env/dataset_test.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Unit tests for `dataset.py`."""
|
17 |
+
|
18 |
+
from typing import Generator, List
|
19 |
+
|
20 |
+
from absl.testing import absltest
|
21 |
+
from absl.testing import parameterized
|
22 |
+
|
23 |
+
from clrs._src import dataset
|
24 |
+
from clrs._src import samplers
|
25 |
+
from clrs._src import specs
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
_Array = np.ndarray
|
29 |
+
|
30 |
+
|
31 |
+
def _stack_to_shortest(x: List[_Array]) -> _Array:
|
32 |
+
min_len = min(map(len, x))
|
33 |
+
return np.array([a[:min_len] for a in x])
|
34 |
+
|
35 |
+
|
36 |
+
def _make_sampler(algo: str) -> samplers.Sampler:
|
37 |
+
sampler, _ = samplers.build_sampler(
|
38 |
+
algo,
|
39 |
+
seed=samplers.CLRS30['val']['seed'],
|
40 |
+
num_samples=samplers.CLRS30['val']['num_samples'],
|
41 |
+
length=samplers.CLRS30['val']['length'],
|
42 |
+
)
|
43 |
+
return sampler
|
44 |
+
|
45 |
+
|
46 |
+
def _make_iterable_sampler(
|
47 |
+
algo: str, batch_size: int) -> Generator[samplers.Feedback, None, None]:
|
48 |
+
sampler = _make_sampler(algo)
|
49 |
+
while True:
|
50 |
+
yield sampler.next(batch_size)
|
51 |
+
|
52 |
+
|
53 |
+
class DatasetTest(parameterized.TestCase):
|
54 |
+
|
55 |
+
@parameterized.product(
|
56 |
+
name=specs.CLRS_30_ALGS[:5],
|
57 |
+
chunk_length=[20, 50])
|
58 |
+
def test_chunkify(self, name: str, chunk_length: int):
|
59 |
+
"""Test that samples are concatenated and split in chunks correctly."""
|
60 |
+
batch_size = 8
|
61 |
+
|
62 |
+
ds = _make_iterable_sampler(name, batch_size)
|
63 |
+
chunked_ds = dataset.chunkify(
|
64 |
+
_make_iterable_sampler(name, batch_size),
|
65 |
+
chunk_length)
|
66 |
+
|
67 |
+
samples = [next(ds) for _ in range(20)]
|
68 |
+
cum_lengths = np.cumsum([s.features.lengths for s in samples], axis=0)
|
69 |
+
n_chunks = np.amax(cum_lengths[-1]).astype(int) // chunk_length + 1
|
70 |
+
chunks = [next(chunked_ds) for _ in range(n_chunks)]
|
71 |
+
|
72 |
+
# Check correctness of `is_first` and `is_last` markers
|
73 |
+
start_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
|
74 |
+
[c.features.is_first for c in chunks]).T]).T
|
75 |
+
end_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate(
|
76 |
+
[c.features.is_last for c in chunks]).T]).T
|
77 |
+
assert len(start_idx) >= len(cum_lengths)
|
78 |
+
start_idx = start_idx[:len(cum_lengths)]
|
79 |
+
assert len(end_idx) >= len(cum_lengths)
|
80 |
+
end_idx = end_idx[:len(cum_lengths)]
|
81 |
+
|
82 |
+
np.testing.assert_equal(start_idx[0], 0)
|
83 |
+
np.testing.assert_array_equal(cum_lengths - 1, end_idx)
|
84 |
+
np.testing.assert_array_equal(cum_lengths[:-1], start_idx[1:])
|
85 |
+
|
86 |
+
# Check that inputs, outputs and hints have been copied correctly
|
87 |
+
all_input = np.concatenate([c.features.inputs[0].data for c in chunks])
|
88 |
+
all_output = np.concatenate([c.outputs[0].data for c in chunks])
|
89 |
+
all_hint = np.concatenate([c.features.hints[0].data for c in chunks])
|
90 |
+
for i in range(batch_size):
|
91 |
+
length0 = int(samples[0].features.lengths[i])
|
92 |
+
length1 = int(samples[1].features.lengths[i])
|
93 |
+
# Check first sample
|
94 |
+
np.testing.assert_array_equal(
|
95 |
+
all_input[:length0, i],
|
96 |
+
np.tile(samples[0].features.inputs[0].data[i], [length0, 1]))
|
97 |
+
np.testing.assert_array_equal(
|
98 |
+
all_output[:length0, i],
|
99 |
+
np.tile(samples[0].outputs[0].data[i], [length0, 1]))
|
100 |
+
np.testing.assert_array_equal(
|
101 |
+
all_hint[:length0, i],
|
102 |
+
samples[0].features.hints[0].data[:length0, i])
|
103 |
+
# Check second sample
|
104 |
+
np.testing.assert_array_equal(
|
105 |
+
all_input[length0:length0 + length1, i],
|
106 |
+
np.tile(samples[1].features.inputs[0].data[i], [length1, 1]))
|
107 |
+
np.testing.assert_array_equal(
|
108 |
+
all_output[length0:length0 + length1, i],
|
109 |
+
np.tile(samples[1].outputs[0].data[i], [length1, 1]))
|
110 |
+
np.testing.assert_array_equal(
|
111 |
+
all_hint[length0:length0 + length1, i],
|
112 |
+
samples[1].features.hints[0].data[:length1, i])
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == '__main__':
|
116 |
+
absltest.main()
|
benchmarks/CLRS/env/decoders.py
ADDED
@@ -0,0 +1,381 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""decoders utilities."""
|
16 |
+
|
17 |
+
import functools
|
18 |
+
from typing import Dict, Optional
|
19 |
+
|
20 |
+
import chex
|
21 |
+
from clrs._src import probing
|
22 |
+
from clrs._src import specs
|
23 |
+
import haiku as hk
|
24 |
+
import jax
|
25 |
+
import jax.numpy as jnp
|
26 |
+
|
27 |
+
_Array = chex.Array
|
28 |
+
_DataPoint = probing.DataPoint
|
29 |
+
_Location = specs.Location
|
30 |
+
_Spec = specs.Spec
|
31 |
+
_Stage = specs.Stage
|
32 |
+
_Type = specs.Type
|
33 |
+
|
34 |
+
|
35 |
+
def log_sinkhorn(x: _Array, steps: int, temperature: float, zero_diagonal: bool,
|
36 |
+
noise_rng_key: Optional[_Array]) -> _Array:
|
37 |
+
"""Sinkhorn operator in log space, to postprocess permutation pointer logits.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
x: input of shape [..., n, n], a batch of square matrices.
|
41 |
+
steps: number of iterations.
|
42 |
+
temperature: temperature parameter (as temperature approaches zero, the
|
43 |
+
output approaches a permutation matrix).
|
44 |
+
zero_diagonal: whether to force the diagonal logits towards -inf.
|
45 |
+
noise_rng_key: key to add Gumbel noise.
|
46 |
+
|
47 |
+
Returns:
|
48 |
+
Elementwise logarithm of a doubly-stochastic matrix (a matrix with
|
49 |
+
non-negative elements whose rows and columns sum to 1).
|
50 |
+
"""
|
51 |
+
assert x.ndim >= 2
|
52 |
+
assert x.shape[-1] == x.shape[-2]
|
53 |
+
if noise_rng_key is not None:
|
54 |
+
# Add standard Gumbel noise (see https://arxiv.org/abs/1802.08665)
|
55 |
+
noise = -jnp.log(-jnp.log(jax.random.uniform(noise_rng_key,
|
56 |
+
x.shape) + 1e-12) + 1e-12)
|
57 |
+
x = x + noise
|
58 |
+
x /= temperature
|
59 |
+
if zero_diagonal:
|
60 |
+
x = x - 1e6 * jnp.eye(x.shape[-1])
|
61 |
+
for _ in range(steps):
|
62 |
+
x = jax.nn.log_softmax(x, axis=-1)
|
63 |
+
x = jax.nn.log_softmax(x, axis=-2)
|
64 |
+
return x
|
65 |
+
|
66 |
+
|
67 |
+
def construct_decoders(loc: str, t: str, hidden_dim: int, nb_dims: int,
|
68 |
+
name: str):
|
69 |
+
"""Constructs decoders."""
|
70 |
+
linear = functools.partial(hk.Linear, name=f"{name}_dec_linear")
|
71 |
+
if loc == _Location.NODE:
|
72 |
+
# Node decoders.
|
73 |
+
if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
|
74 |
+
decoders = (linear(1),)
|
75 |
+
elif t == _Type.CATEGORICAL:
|
76 |
+
decoders = (linear(nb_dims),)
|
77 |
+
elif t in [_Type.POINTER, _Type.PERMUTATION_POINTER]:
|
78 |
+
decoders = (linear(hidden_dim), linear(hidden_dim), linear(hidden_dim),
|
79 |
+
linear(1))
|
80 |
+
else:
|
81 |
+
raise ValueError(f"Invalid Type {t}")
|
82 |
+
|
83 |
+
elif loc == _Location.EDGE:
|
84 |
+
# Edge decoders.
|
85 |
+
if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
|
86 |
+
decoders = (linear(1), linear(1), linear(1))
|
87 |
+
elif t == _Type.CATEGORICAL:
|
88 |
+
decoders = (linear(nb_dims), linear(nb_dims), linear(nb_dims))
|
89 |
+
elif t == _Type.POINTER:
|
90 |
+
decoders = (linear(hidden_dim), linear(hidden_dim),
|
91 |
+
linear(hidden_dim), linear(hidden_dim), linear(1))
|
92 |
+
else:
|
93 |
+
raise ValueError(f"Invalid Type {t}")
|
94 |
+
|
95 |
+
elif loc == _Location.GRAPH:
|
96 |
+
# Graph decoders.
|
97 |
+
if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
|
98 |
+
decoders = (linear(1), linear(1))
|
99 |
+
elif t == _Type.CATEGORICAL:
|
100 |
+
decoders = (linear(nb_dims), linear(nb_dims))
|
101 |
+
elif t == _Type.POINTER:
|
102 |
+
decoders = (linear(1), linear(1),
|
103 |
+
linear(1))
|
104 |
+
else:
|
105 |
+
raise ValueError(f"Invalid Type {t}")
|
106 |
+
|
107 |
+
else:
|
108 |
+
raise ValueError(f"Invalid Location {loc}")
|
109 |
+
|
110 |
+
return decoders
|
111 |
+
|
112 |
+
|
113 |
+
def construct_diff_decoders(name: str):
|
114 |
+
"""Constructs diff decoders."""
|
115 |
+
linear = functools.partial(hk.Linear, name=f"{name}_diffdec_linear")
|
116 |
+
decoders = {}
|
117 |
+
decoders[_Location.NODE] = linear(1)
|
118 |
+
decoders[_Location.EDGE] = (linear(1), linear(1), linear(1))
|
119 |
+
decoders[_Location.GRAPH] = (linear(1), linear(1))
|
120 |
+
|
121 |
+
return decoders
|
122 |
+
|
123 |
+
|
124 |
+
def postprocess(spec: _Spec, preds: Dict[str, _Array],
|
125 |
+
sinkhorn_temperature: float,
|
126 |
+
sinkhorn_steps: int,
|
127 |
+
hard: bool) -> Dict[str, _DataPoint]:
|
128 |
+
"""Postprocesses decoder output.
|
129 |
+
|
130 |
+
This is done on outputs in order to score performance, and on hints in
|
131 |
+
order to score them but also in order to feed them back to the model.
|
132 |
+
At scoring time, the postprocessing mode is "hard", logits will be
|
133 |
+
arg-maxed and masks will be thresholded. However, for the case of the hints
|
134 |
+
that are fed back in the model, the postprocessing can be hard or soft,
|
135 |
+
depending on whether we want to let gradients flow through them or not.
|
136 |
+
|
137 |
+
Args:
|
138 |
+
spec: The spec of the algorithm whose outputs/hints we are postprocessing.
|
139 |
+
preds: Output and/or hint predictions, as produced by decoders.
|
140 |
+
sinkhorn_temperature: Parameter for the sinkhorn operator on permutation
|
141 |
+
pointers.
|
142 |
+
sinkhorn_steps: Parameter for the sinkhorn operator on permutation
|
143 |
+
pointers.
|
144 |
+
hard: whether to do hard postprocessing, which involves argmax for
|
145 |
+
MASK_ONE, CATEGORICAL and POINTERS, thresholding for MASK, and stop
|
146 |
+
gradient through for SCALAR. If False, soft postprocessing will be used,
|
147 |
+
with softmax, sigmoid and gradients allowed.
|
148 |
+
Returns:
|
149 |
+
The postprocessed `preds`. In "soft" post-processing, POINTER types will
|
150 |
+
change to SOFT_POINTER, so encoders know they do not need to be
|
151 |
+
pre-processed before feeding them back in.
|
152 |
+
"""
|
153 |
+
result = {}
|
154 |
+
for name in preds.keys():
|
155 |
+
_, loc, t = spec[name]
|
156 |
+
new_t = t
|
157 |
+
data = preds[name]
|
158 |
+
if t == _Type.SCALAR:
|
159 |
+
if hard:
|
160 |
+
data = jax.lax.stop_gradient(data)
|
161 |
+
elif t == _Type.MASK:
|
162 |
+
if hard:
|
163 |
+
data = (data > 0.0) * 1.0
|
164 |
+
else:
|
165 |
+
data = jax.nn.sigmoid(data)
|
166 |
+
elif t in [_Type.MASK_ONE, _Type.CATEGORICAL]:
|
167 |
+
cat_size = data.shape[-1]
|
168 |
+
if hard:
|
169 |
+
best = jnp.argmax(data, -1)
|
170 |
+
data = hk.one_hot(best, cat_size)
|
171 |
+
else:
|
172 |
+
data = jax.nn.softmax(data, axis=-1)
|
173 |
+
elif t == _Type.POINTER:
|
174 |
+
if hard:
|
175 |
+
data = jnp.argmax(data, -1).astype(float)
|
176 |
+
else:
|
177 |
+
data = jax.nn.softmax(data, -1)
|
178 |
+
new_t = _Type.SOFT_POINTER
|
179 |
+
elif t == _Type.PERMUTATION_POINTER:
|
180 |
+
# Convert the matrix of logits to a doubly stochastic matrix.
|
181 |
+
data = log_sinkhorn(
|
182 |
+
x=data,
|
183 |
+
steps=sinkhorn_steps,
|
184 |
+
temperature=sinkhorn_temperature,
|
185 |
+
zero_diagonal=True,
|
186 |
+
noise_rng_key=None)
|
187 |
+
data = jnp.exp(data)
|
188 |
+
if hard:
|
189 |
+
data = jax.nn.one_hot(jnp.argmax(data, axis=-1), data.shape[-1])
|
190 |
+
else:
|
191 |
+
raise ValueError("Invalid type")
|
192 |
+
result[name] = probing.DataPoint(
|
193 |
+
name=name, location=loc, type_=new_t, data=data)
|
194 |
+
|
195 |
+
return result
|
196 |
+
|
197 |
+
|
198 |
+
def decode_fts(
|
199 |
+
decoders,
|
200 |
+
spec: _Spec,
|
201 |
+
h_t: _Array,
|
202 |
+
adj_mat: _Array,
|
203 |
+
edge_fts: _Array,
|
204 |
+
graph_fts: _Array,
|
205 |
+
inf_bias: bool,
|
206 |
+
inf_bias_edge: bool,
|
207 |
+
repred: bool,
|
208 |
+
):
|
209 |
+
"""Decodes node, edge and graph features."""
|
210 |
+
output_preds = {}
|
211 |
+
hint_preds = {}
|
212 |
+
|
213 |
+
for name in decoders:
|
214 |
+
decoder = decoders[name]
|
215 |
+
stage, loc, t = spec[name]
|
216 |
+
|
217 |
+
if loc == _Location.NODE:
|
218 |
+
preds = _decode_node_fts(decoder, t, h_t, edge_fts, adj_mat,
|
219 |
+
inf_bias, repred)
|
220 |
+
elif loc == _Location.EDGE:
|
221 |
+
preds = _decode_edge_fts(decoder, t, h_t, edge_fts, adj_mat,
|
222 |
+
inf_bias_edge)
|
223 |
+
elif loc == _Location.GRAPH:
|
224 |
+
preds = _decode_graph_fts(decoder, t, h_t, graph_fts)
|
225 |
+
else:
|
226 |
+
raise ValueError("Invalid output type")
|
227 |
+
|
228 |
+
if stage == _Stage.OUTPUT:
|
229 |
+
output_preds[name] = preds
|
230 |
+
elif stage == _Stage.HINT:
|
231 |
+
hint_preds[name] = preds
|
232 |
+
else:
|
233 |
+
raise ValueError(f"Found unexpected decoder {name}")
|
234 |
+
|
235 |
+
return hint_preds, output_preds
|
236 |
+
|
237 |
+
|
238 |
+
def _decode_node_fts(decoders, t: str, h_t: _Array, edge_fts: _Array,
|
239 |
+
adj_mat: _Array, inf_bias: bool, repred: bool) -> _Array:
|
240 |
+
"""Decodes node features."""
|
241 |
+
|
242 |
+
if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
|
243 |
+
preds = jnp.squeeze(decoders[0](h_t), -1)
|
244 |
+
elif t == _Type.CATEGORICAL:
|
245 |
+
preds = decoders[0](h_t)
|
246 |
+
elif t in [_Type.POINTER, _Type.PERMUTATION_POINTER]:
|
247 |
+
p_1 = decoders[0](h_t)
|
248 |
+
p_2 = decoders[1](h_t)
|
249 |
+
p_3 = decoders[2](edge_fts)
|
250 |
+
|
251 |
+
p_e = jnp.expand_dims(p_2, -2) + p_3
|
252 |
+
p_m = jnp.maximum(jnp.expand_dims(p_1, -2),
|
253 |
+
jnp.transpose(p_e, (0, 2, 1, 3)))
|
254 |
+
|
255 |
+
preds = jnp.squeeze(decoders[3](p_m), -1)
|
256 |
+
|
257 |
+
if inf_bias:
|
258 |
+
per_batch_min = jnp.min(preds, axis=range(1, preds.ndim), keepdims=True)
|
259 |
+
preds = jnp.where(adj_mat > 0.5,
|
260 |
+
preds,
|
261 |
+
jnp.minimum(-1.0, per_batch_min - 1.0))
|
262 |
+
if t == _Type.PERMUTATION_POINTER:
|
263 |
+
if repred: # testing or validation, no Gumbel noise
|
264 |
+
preds = log_sinkhorn(
|
265 |
+
x=preds, steps=10, temperature=0.1,
|
266 |
+
zero_diagonal=True, noise_rng_key=None)
|
267 |
+
else: # training, add Gumbel noise
|
268 |
+
preds = log_sinkhorn(
|
269 |
+
x=preds, steps=10, temperature=0.1,
|
270 |
+
zero_diagonal=True, noise_rng_key=hk.next_rng_key())
|
271 |
+
else:
|
272 |
+
raise ValueError("Invalid output type")
|
273 |
+
|
274 |
+
return preds
|
275 |
+
|
276 |
+
|
277 |
+
def _decode_edge_fts(decoders, t: str, h_t: _Array, edge_fts: _Array,
|
278 |
+
adj_mat: _Array, inf_bias_edge: bool) -> _Array:
|
279 |
+
"""Decodes edge features."""
|
280 |
+
|
281 |
+
pred_1 = decoders[0](h_t)
|
282 |
+
pred_2 = decoders[1](h_t)
|
283 |
+
pred_e = decoders[2](edge_fts)
|
284 |
+
pred = (jnp.expand_dims(pred_1, -2) + jnp.expand_dims(pred_2, -3) + pred_e)
|
285 |
+
if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
|
286 |
+
preds = jnp.squeeze(pred, -1)
|
287 |
+
elif t == _Type.CATEGORICAL:
|
288 |
+
preds = pred
|
289 |
+
elif t == _Type.POINTER:
|
290 |
+
pred_2 = decoders[3](h_t)
|
291 |
+
|
292 |
+
p_m = jnp.maximum(jnp.expand_dims(pred, -2),
|
293 |
+
jnp.expand_dims(
|
294 |
+
jnp.expand_dims(pred_2, -3), -3))
|
295 |
+
|
296 |
+
preds = jnp.squeeze(decoders[4](p_m), -1)
|
297 |
+
else:
|
298 |
+
raise ValueError("Invalid output type")
|
299 |
+
if inf_bias_edge and t in [_Type.MASK, _Type.MASK_ONE]:
|
300 |
+
per_batch_min = jnp.min(preds, axis=range(1, preds.ndim), keepdims=True)
|
301 |
+
preds = jnp.where(adj_mat > 0.5,
|
302 |
+
preds,
|
303 |
+
jnp.minimum(-1.0, per_batch_min - 1.0))
|
304 |
+
|
305 |
+
return preds
|
306 |
+
|
307 |
+
|
308 |
+
def _decode_graph_fts(decoders, t: str, h_t: _Array,
|
309 |
+
graph_fts: _Array) -> _Array:
|
310 |
+
"""Decodes graph features."""
|
311 |
+
|
312 |
+
gr_emb = jnp.max(h_t, axis=-2)
|
313 |
+
pred_n = decoders[0](gr_emb)
|
314 |
+
pred_g = decoders[1](graph_fts)
|
315 |
+
pred = pred_n + pred_g
|
316 |
+
if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]:
|
317 |
+
preds = jnp.squeeze(pred, -1)
|
318 |
+
elif t == _Type.CATEGORICAL:
|
319 |
+
preds = pred
|
320 |
+
elif t == _Type.POINTER:
|
321 |
+
pred_2 = decoders[2](h_t)
|
322 |
+
ptr_p = jnp.expand_dims(pred, 1) + jnp.transpose(pred_2, (0, 2, 1))
|
323 |
+
preds = jnp.squeeze(ptr_p, 1)
|
324 |
+
else:
|
325 |
+
raise ValueError("Invalid output type")
|
326 |
+
|
327 |
+
return preds
|
328 |
+
|
329 |
+
|
330 |
+
def maybe_decode_diffs(
|
331 |
+
diff_decoders,
|
332 |
+
h_t: _Array,
|
333 |
+
edge_fts: _Array,
|
334 |
+
graph_fts: _Array,
|
335 |
+
decode_diffs: bool,
|
336 |
+
) -> Optional[Dict[str, _Array]]:
|
337 |
+
"""Optionally decodes node, edge and graph diffs."""
|
338 |
+
|
339 |
+
if decode_diffs:
|
340 |
+
preds = {}
|
341 |
+
node = _Location.NODE
|
342 |
+
edge = _Location.EDGE
|
343 |
+
graph = _Location.GRAPH
|
344 |
+
preds[node] = _decode_node_diffs(diff_decoders[node], h_t)
|
345 |
+
preds[edge] = _decode_edge_diffs(diff_decoders[edge], h_t, edge_fts)
|
346 |
+
preds[graph] = _decode_graph_diffs(diff_decoders[graph], h_t, graph_fts)
|
347 |
+
|
348 |
+
else:
|
349 |
+
preds = None
|
350 |
+
|
351 |
+
return preds
|
352 |
+
|
353 |
+
|
354 |
+
def _decode_node_diffs(decoders, h_t: _Array) -> _Array:
|
355 |
+
"""Decodes node diffs."""
|
356 |
+
return jnp.squeeze(decoders(h_t), -1)
|
357 |
+
|
358 |
+
|
359 |
+
def _decode_edge_diffs(decoders, h_t: _Array, edge_fts: _Array) -> _Array:
|
360 |
+
"""Decodes edge diffs."""
|
361 |
+
|
362 |
+
e_pred_1 = decoders[0](h_t)
|
363 |
+
e_pred_2 = decoders[1](h_t)
|
364 |
+
e_pred_e = decoders[2](edge_fts)
|
365 |
+
preds = jnp.squeeze(
|
366 |
+
jnp.expand_dims(e_pred_1, -1) + jnp.expand_dims(e_pred_2, -2) + e_pred_e,
|
367 |
+
-1,
|
368 |
+
)
|
369 |
+
|
370 |
+
return preds
|
371 |
+
|
372 |
+
|
373 |
+
def _decode_graph_diffs(decoders, h_t: _Array, graph_fts: _Array) -> _Array:
|
374 |
+
"""Decodes graph diffs."""
|
375 |
+
|
376 |
+
gr_emb = jnp.max(h_t, axis=-2)
|
377 |
+
g_pred_n = decoders[0](gr_emb)
|
378 |
+
g_pred_g = decoders[1](graph_fts)
|
379 |
+
preds = jnp.squeeze(g_pred_n + g_pred_g, -1)
|
380 |
+
|
381 |
+
return preds
|
benchmarks/CLRS/env/decoders_test.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Unit tests for `decoders.py`."""
|
17 |
+
|
18 |
+
from absl.testing import absltest
|
19 |
+
|
20 |
+
import chex
|
21 |
+
from clrs._src import decoders
|
22 |
+
import jax
|
23 |
+
import jax.numpy as jnp
|
24 |
+
|
25 |
+
|
26 |
+
class DecodersTest(absltest.TestCase):
|
27 |
+
|
28 |
+
def test_log_sinkhorn(self):
|
29 |
+
x = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
|
30 |
+
y = jnp.exp(decoders.log_sinkhorn(x, steps=10, temperature=1.0,
|
31 |
+
zero_diagonal=False,
|
32 |
+
noise_rng_key=None))
|
33 |
+
chex.assert_trees_all_close(jnp.sum(y, axis=-1), 1., atol=1e-4)
|
34 |
+
chex.assert_trees_all_close(jnp.sum(y, axis=-2), 1., atol=1e-4)
|
35 |
+
|
36 |
+
def test_log_sinkhorn_zero_diagonal(self):
|
37 |
+
x = jax.random.normal(jax.random.PRNGKey(42), (10, 10))
|
38 |
+
y = jnp.exp(decoders.log_sinkhorn(x, steps=10, temperature=1.0,
|
39 |
+
zero_diagonal=True,
|
40 |
+
noise_rng_key=None))
|
41 |
+
chex.assert_trees_all_close(jnp.sum(y, axis=-1), 1., atol=1e-4)
|
42 |
+
chex.assert_trees_all_close(jnp.sum(y, axis=-2), 1., atol=1e-4)
|
43 |
+
chex.assert_trees_all_close(jnp.sum(y.diagonal()), 0., atol=1e-4)
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == '__main__':
|
47 |
+
absltest.main()
|
benchmarks/CLRS/env/encoders.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Encoder utilities."""
|
16 |
+
|
17 |
+
import functools
|
18 |
+
import chex
|
19 |
+
from clrs._src import probing
|
20 |
+
from clrs._src import specs
|
21 |
+
import haiku as hk
|
22 |
+
import jax.numpy as jnp
|
23 |
+
|
24 |
+
_Array = chex.Array
|
25 |
+
_DataPoint = probing.DataPoint
|
26 |
+
_Location = specs.Location
|
27 |
+
_Spec = specs.Spec
|
28 |
+
_Stage = specs.Stage
|
29 |
+
_Type = specs.Type
|
30 |
+
|
31 |
+
|
32 |
+
def construct_encoders(stage: str, loc: str, t: str,
|
33 |
+
hidden_dim: int, init: str, name: str):
|
34 |
+
"""Constructs encoders."""
|
35 |
+
if init == 'xavier_on_scalars' and stage == _Stage.HINT and t == _Type.SCALAR:
|
36 |
+
initialiser = hk.initializers.TruncatedNormal(
|
37 |
+
stddev=1.0 / jnp.sqrt(hidden_dim))
|
38 |
+
elif init in ['default', 'xavier_on_scalars']:
|
39 |
+
initialiser = None
|
40 |
+
else:
|
41 |
+
raise ValueError(f'Encoder initialiser {init} not supported.')
|
42 |
+
linear = functools.partial(
|
43 |
+
hk.Linear,
|
44 |
+
w_init=initialiser,
|
45 |
+
name=f'{name}_enc_linear')
|
46 |
+
encoders = [linear(hidden_dim)]
|
47 |
+
if loc == _Location.EDGE and t == _Type.POINTER:
|
48 |
+
# Edge pointers need two-way encoders.
|
49 |
+
encoders.append(linear(hidden_dim))
|
50 |
+
|
51 |
+
return encoders
|
52 |
+
|
53 |
+
|
54 |
+
def preprocess(dp: _DataPoint, nb_nodes: int) -> _DataPoint:
|
55 |
+
"""Pre-process data point.
|
56 |
+
|
57 |
+
Make sure that the data is ready to be encoded into features.
|
58 |
+
If the data is of POINTER type, we expand the compressed index representation
|
59 |
+
to a full one-hot. But if the data is a SOFT_POINTER, the representation
|
60 |
+
is already expanded and we just overwrite the type as POINTER so that
|
61 |
+
it is treated as such for encoding.
|
62 |
+
|
63 |
+
Args:
|
64 |
+
dp: A DataPoint to prepare for encoding.
|
65 |
+
nb_nodes: Number of nodes in the graph, necessary to expand pointers to
|
66 |
+
the right dimension.
|
67 |
+
Returns:
|
68 |
+
The datapoint, with data and possibly type modified.
|
69 |
+
"""
|
70 |
+
new_type = dp.type_
|
71 |
+
if dp.type_ == _Type.POINTER:
|
72 |
+
data = hk.one_hot(dp.data, nb_nodes)
|
73 |
+
else:
|
74 |
+
data = dp.data.astype(jnp.float32)
|
75 |
+
if dp.type_ == _Type.SOFT_POINTER:
|
76 |
+
new_type = _Type.POINTER
|
77 |
+
dp = probing.DataPoint(
|
78 |
+
name=dp.name, location=dp.location, type_=new_type, data=data)
|
79 |
+
|
80 |
+
return dp
|
81 |
+
|
82 |
+
|
83 |
+
def accum_adj_mat(dp: _DataPoint, adj_mat: _Array) -> _Array:
|
84 |
+
"""Accumulates adjacency matrix."""
|
85 |
+
if dp.location == _Location.NODE and dp.type_ in [_Type.POINTER,
|
86 |
+
_Type.PERMUTATION_POINTER]:
|
87 |
+
adj_mat += ((dp.data + jnp.transpose(dp.data, (0, 2, 1))) > 0.5)
|
88 |
+
elif dp.location == _Location.EDGE and dp.type_ == _Type.MASK:
|
89 |
+
adj_mat += ((dp.data + jnp.transpose(dp.data, (0, 2, 1))) > 0.0)
|
90 |
+
|
91 |
+
return (adj_mat > 0.).astype('float32') # pytype: disable=attribute-error # numpy-scalars
|
92 |
+
|
93 |
+
|
94 |
+
def accum_edge_fts(encoders, dp: _DataPoint, edge_fts: _Array) -> _Array:
|
95 |
+
"""Encodes and accumulates edge features."""
|
96 |
+
if dp.location == _Location.NODE and dp.type_ in [_Type.POINTER,
|
97 |
+
_Type.PERMUTATION_POINTER]:
|
98 |
+
encoding = _encode_inputs(encoders, dp)
|
99 |
+
edge_fts += encoding
|
100 |
+
|
101 |
+
elif dp.location == _Location.EDGE:
|
102 |
+
encoding = _encode_inputs(encoders, dp)
|
103 |
+
if dp.type_ == _Type.POINTER:
|
104 |
+
# Aggregate pointer contributions across sender and receiver nodes.
|
105 |
+
encoding_2 = encoders[1](jnp.expand_dims(dp.data, -1))
|
106 |
+
edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2)
|
107 |
+
else:
|
108 |
+
edge_fts += encoding
|
109 |
+
|
110 |
+
return edge_fts
|
111 |
+
|
112 |
+
|
113 |
+
def accum_node_fts(encoders, dp: _DataPoint, node_fts: _Array) -> _Array:
|
114 |
+
"""Encodes and accumulates node features."""
|
115 |
+
is_pointer = (dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER])
|
116 |
+
if ((dp.location == _Location.NODE and not is_pointer) or
|
117 |
+
(dp.location == _Location.GRAPH and dp.type_ == _Type.POINTER)):
|
118 |
+
encoding = _encode_inputs(encoders, dp)
|
119 |
+
node_fts += encoding
|
120 |
+
|
121 |
+
return node_fts
|
122 |
+
|
123 |
+
|
124 |
+
def accum_graph_fts(encoders, dp: _DataPoint,
|
125 |
+
graph_fts: _Array) -> _Array:
|
126 |
+
"""Encodes and accumulates graph features."""
|
127 |
+
if dp.location == _Location.GRAPH and dp.type_ != _Type.POINTER:
|
128 |
+
encoding = _encode_inputs(encoders, dp)
|
129 |
+
graph_fts += encoding
|
130 |
+
|
131 |
+
return graph_fts
|
132 |
+
|
133 |
+
|
134 |
+
def _encode_inputs(encoders, dp: _DataPoint) -> _Array:
|
135 |
+
if dp.type_ == _Type.CATEGORICAL:
|
136 |
+
encoding = encoders[0](dp.data)
|
137 |
+
else:
|
138 |
+
encoding = encoders[0](jnp.expand_dims(dp.data, -1))
|
139 |
+
return encoding
|
benchmarks/CLRS/env/evaluation.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Model base classes and utilities."""
|
17 |
+
|
18 |
+
from typing import Dict, List, Tuple
|
19 |
+
import chex
|
20 |
+
from clrs._src import probing
|
21 |
+
from clrs._src import specs
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
|
25 |
+
_Array = chex.Array
|
26 |
+
Result = Dict[str, probing.DataPoint]
|
27 |
+
|
28 |
+
|
29 |
+
def fuse_perm_and_mask(perm: probing.DataPoint,
|
30 |
+
mask: probing.DataPoint) -> probing.DataPoint:
|
31 |
+
"""Replace permutation pointers active in the mask with self-pointers.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
perm: a node permutation_pointer; data shape is expected to be
|
35 |
+
[..., N, N], and ideally one-hot over the last two dimensions, although
|
36 |
+
this method does not check for one-hotness.
|
37 |
+
mask: a mask_one over nodes; data shape is expected to be
|
38 |
+
[..., N], and ideally one-hot over the last dimension, although
|
39 |
+
this method does not check for one-hotness.
|
40 |
+
Returns:
|
41 |
+
A node pointer with shape [..., N].
|
42 |
+
"""
|
43 |
+
assert perm.type_ == specs.Type.PERMUTATION_POINTER
|
44 |
+
assert perm.location == specs.Location.NODE
|
45 |
+
assert mask.name == perm.name + '_mask'
|
46 |
+
assert mask.type_ == specs.Type.MASK_ONE
|
47 |
+
assert mask.location == specs.Location.NODE
|
48 |
+
assert perm.data.shape[-1] == perm.data.shape[-2]
|
49 |
+
assert perm.data.shape[:-1] == mask.data.shape
|
50 |
+
data = np.where(mask.data > 0.5,
|
51 |
+
np.arange(perm.data.shape[-1]), # self-pointers
|
52 |
+
np.argmax(perm.data, axis=-1)) # original pointers
|
53 |
+
return probing.DataPoint(name=perm.name,
|
54 |
+
type_=specs.Type.POINTER,
|
55 |
+
location=perm.location,
|
56 |
+
data=data)
|
57 |
+
|
58 |
+
|
59 |
+
def _reduce_permutations_tuple(
|
60 |
+
targets: Tuple[probing.DataPoint, ...]) -> Tuple[probing.DataPoint, ...]:
|
61 |
+
"""Reduce node pointer + mask_one permutation to just node pointer."""
|
62 |
+
out_targets = []
|
63 |
+
n_perms = 0
|
64 |
+
i = 0
|
65 |
+
while i < len(targets):
|
66 |
+
truth = targets[i]
|
67 |
+
if truth.type_ != specs.Type.PERMUTATION_POINTER:
|
68 |
+
out_targets.append(truth)
|
69 |
+
i += 1
|
70 |
+
continue
|
71 |
+
truth_mask = targets[i + 1]
|
72 |
+
out_targets.append(fuse_perm_and_mask(truth, truth_mask))
|
73 |
+
i += 2
|
74 |
+
n_perms += 1
|
75 |
+
|
76 |
+
assert len(out_targets) == len(targets) - n_perms
|
77 |
+
return tuple(out_targets)
|
78 |
+
|
79 |
+
|
80 |
+
def _reduce_permutations_dict(predictions: Result) -> Result:
|
81 |
+
"""Reduce node pointer + mask_one permutation to just node pointer."""
|
82 |
+
out_preds = {}
|
83 |
+
n_perms = 0
|
84 |
+
for k, pred in predictions.items():
|
85 |
+
if (k.endswith('_mask') and k[:-5] in predictions and
|
86 |
+
predictions[k[:-5]].type_ == specs.Type.PERMUTATION_POINTER):
|
87 |
+
# This mask will be processed with its associated permutation datapoint
|
88 |
+
continue
|
89 |
+
if pred.type_ != specs.Type.PERMUTATION_POINTER:
|
90 |
+
out_preds[k] = pred
|
91 |
+
continue
|
92 |
+
pred_mask = predictions[k + '_mask']
|
93 |
+
out_preds[k] = fuse_perm_and_mask(pred, pred_mask)
|
94 |
+
n_perms += 1
|
95 |
+
|
96 |
+
assert len(out_preds) == len(predictions) - n_perms
|
97 |
+
return out_preds
|
98 |
+
|
99 |
+
|
100 |
+
def evaluate_hints(
|
101 |
+
hints: Tuple[probing.DataPoint, ...],
|
102 |
+
lengths: _Array,
|
103 |
+
hint_preds: List[Result],
|
104 |
+
) -> Dict[str, _Array]:
|
105 |
+
"""Evaluate hint predictions."""
|
106 |
+
evals = {}
|
107 |
+
hints = _reduce_permutations_tuple(hints)
|
108 |
+
hint_preds = [_reduce_permutations_dict(h) for h in hint_preds]
|
109 |
+
for truth in hints:
|
110 |
+
assert truth.name in hint_preds[0]
|
111 |
+
eval_along_time = [_evaluate(truth, p[truth.name],
|
112 |
+
idx=i+1, lengths=lengths)
|
113 |
+
for (i, p) in enumerate(hint_preds)]
|
114 |
+
evals[truth.name] = np.sum(
|
115 |
+
[x * np.sum(i+1 < lengths)
|
116 |
+
for i, x in enumerate(eval_along_time)]) / np.sum(lengths - 1)
|
117 |
+
evals[truth.name + '_along_time'] = np.array(eval_along_time)
|
118 |
+
|
119 |
+
# Unlike outputs, the hints sometimes include scalars, which don't have
|
120 |
+
# a meaningful eval score. So we don't compute a global 'hint score' as we
|
121 |
+
# do for outputs.
|
122 |
+
return evals
|
123 |
+
|
124 |
+
|
125 |
+
def evaluate(
|
126 |
+
outputs: Tuple[probing.DataPoint, ...],
|
127 |
+
predictions: Result,
|
128 |
+
) -> Dict[str, float]:
|
129 |
+
"""Evaluate output predictions."""
|
130 |
+
evals = {}
|
131 |
+
outputs = _reduce_permutations_tuple(outputs)
|
132 |
+
predictions = _reduce_permutations_dict(predictions)
|
133 |
+
for truth in outputs:
|
134 |
+
assert truth.name in predictions
|
135 |
+
pred = predictions[truth.name]
|
136 |
+
evals[truth.name] = _evaluate(truth, pred)
|
137 |
+
# Return a single scalar score that is the mean of all output scores.
|
138 |
+
evals['score'] = sum([v.item() for v in evals.values()]) / len(evals)
|
139 |
+
return evals
|
140 |
+
|
141 |
+
|
142 |
+
def _evaluate(truth, pred, idx=None, lengths=None):
|
143 |
+
"""Evaluate single prediction of hint or output."""
|
144 |
+
assert pred.name == truth.name
|
145 |
+
assert pred.location == truth.location
|
146 |
+
assert pred.type_ == truth.type_
|
147 |
+
|
148 |
+
if truth.type_ not in _EVAL_FN:
|
149 |
+
raise ValueError('Invalid type')
|
150 |
+
truth_data = truth.data
|
151 |
+
pred_data = pred.data
|
152 |
+
if idx is not None:
|
153 |
+
if np.all(idx >= lengths):
|
154 |
+
return 0.
|
155 |
+
truth_data = truth_data[idx][idx < lengths]
|
156 |
+
pred_data = pred_data[idx < lengths]
|
157 |
+
return _EVAL_FN[truth.type_](pred_data, truth_data)
|
158 |
+
|
159 |
+
|
160 |
+
def _eval_one(pred, truth):
|
161 |
+
mask = np.all(truth != specs.OutputClass.MASKED, axis=-1)
|
162 |
+
return np.sum(
|
163 |
+
(np.argmax(pred, -1) == np.argmax(truth, -1)) * mask) / np.sum(mask)
|
164 |
+
|
165 |
+
|
166 |
+
def _mask_fn(pred, truth):
|
167 |
+
"""Evaluate outputs of type MASK, and account for any class imbalance."""
|
168 |
+
mask = (truth != specs.OutputClass.MASKED).astype(np.float32)
|
169 |
+
|
170 |
+
# Use F1 score for the masked outputs to address any imbalance
|
171 |
+
tp = np.sum((((pred > 0.5) * (truth > 0.5)) * 1.0) * mask)
|
172 |
+
fp = np.sum((((pred > 0.5) * (truth < 0.5)) * 1.0) * mask)
|
173 |
+
fn = np.sum((((pred < 0.5) * (truth > 0.5)) * 1.0) * mask)
|
174 |
+
|
175 |
+
# Protect against division by zero
|
176 |
+
if tp + fp > 0:
|
177 |
+
precision = tp / (tp + fp)
|
178 |
+
else:
|
179 |
+
precision = np.float32(1.0)
|
180 |
+
if tp + fn > 0:
|
181 |
+
recall = tp / (tp + fn)
|
182 |
+
else:
|
183 |
+
recall = np.float32(1.0)
|
184 |
+
|
185 |
+
if precision + recall > 0.0:
|
186 |
+
f_1 = 2.0 * precision * recall / (precision + recall)
|
187 |
+
else:
|
188 |
+
f_1 = np.float32(0.0)
|
189 |
+
|
190 |
+
return f_1
|
191 |
+
|
192 |
+
_EVAL_FN = {
|
193 |
+
specs.Type.SCALAR:
|
194 |
+
lambda pred, truth: np.mean((pred - truth)**2),
|
195 |
+
specs.Type.MASK: _mask_fn,
|
196 |
+
specs.Type.MASK_ONE:
|
197 |
+
_eval_one,
|
198 |
+
specs.Type.CATEGORICAL:
|
199 |
+
_eval_one,
|
200 |
+
specs.Type.POINTER:
|
201 |
+
lambda pred, truth: np.mean((pred == truth) * 1.0),
|
202 |
+
}
|
benchmarks/CLRS/env/evaluation_test.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Unit tests for `evaluation.py`."""
|
17 |
+
|
18 |
+
from absl.testing import absltest
|
19 |
+
from clrs._src import evaluation
|
20 |
+
from clrs._src import probing
|
21 |
+
from clrs._src import specs
|
22 |
+
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
|
28 |
+
class EvaluationTest(absltest.TestCase):
|
29 |
+
|
30 |
+
def test_reduce_permutations(self):
|
31 |
+
b = 8
|
32 |
+
n = 16
|
33 |
+
pred = jnp.stack([jax.random.permutation(jax.random.PRNGKey(i), n)
|
34 |
+
for i in range(b)])
|
35 |
+
heads = jax.random.randint(jax.random.PRNGKey(42), (b,), 0, n)
|
36 |
+
|
37 |
+
perm = probing.DataPoint(name='test',
|
38 |
+
type_=specs.Type.PERMUTATION_POINTER,
|
39 |
+
location=specs.Location.NODE,
|
40 |
+
data=jax.nn.one_hot(pred, n))
|
41 |
+
mask = probing.DataPoint(name='test_mask',
|
42 |
+
type_=specs.Type.MASK_ONE,
|
43 |
+
location=specs.Location.NODE,
|
44 |
+
data=jax.nn.one_hot(heads, n))
|
45 |
+
output = evaluation.fuse_perm_and_mask(perm=perm, mask=mask)
|
46 |
+
expected_output = np.array(pred)
|
47 |
+
expected_output[np.arange(b), heads] = heads
|
48 |
+
self.assertEqual(output.name, 'test')
|
49 |
+
self.assertEqual(output.type_, specs.Type.POINTER)
|
50 |
+
self.assertEqual(output.location, specs.Location.NODE)
|
51 |
+
np.testing.assert_allclose(output.data, expected_output)
|
52 |
+
|
53 |
+
|
54 |
+
if __name__ == '__main__':
|
55 |
+
absltest.main()
|
benchmarks/CLRS/env/losses.py
ADDED
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
"""Utilities for calculating losses."""
|
16 |
+
|
17 |
+
from typing import Dict, List, Tuple
|
18 |
+
import chex
|
19 |
+
from clrs._src import probing
|
20 |
+
from clrs._src import specs
|
21 |
+
|
22 |
+
import haiku as hk
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
|
26 |
+
_Array = chex.Array
|
27 |
+
_DataPoint = probing.DataPoint
|
28 |
+
_Location = specs.Location
|
29 |
+
_OutputClass = specs.OutputClass
|
30 |
+
_PredTrajectory = Dict[str, _Array]
|
31 |
+
_PredTrajectories = List[_PredTrajectory]
|
32 |
+
_Type = specs.Type
|
33 |
+
|
34 |
+
EPS = 1e-12
|
35 |
+
|
36 |
+
|
37 |
+
def _expand_to(x: _Array, y: _Array) -> _Array:
|
38 |
+
while len(y.shape) > len(x.shape):
|
39 |
+
x = jnp.expand_dims(x, -1)
|
40 |
+
return x
|
41 |
+
|
42 |
+
|
43 |
+
def _expand_and_broadcast_to(x: _Array, y: _Array) -> _Array:
|
44 |
+
return jnp.broadcast_to(_expand_to(x, y), y.shape)
|
45 |
+
|
46 |
+
|
47 |
+
def output_loss_chunked(truth: _DataPoint, pred: _Array,
|
48 |
+
is_last: _Array, nb_nodes: int) -> float:
|
49 |
+
"""Output loss for time-chunked training."""
|
50 |
+
|
51 |
+
mask = None
|
52 |
+
|
53 |
+
if truth.type_ == _Type.SCALAR:
|
54 |
+
loss = (pred - truth.data)**2
|
55 |
+
|
56 |
+
elif truth.type_ == _Type.MASK:
|
57 |
+
loss = (
|
58 |
+
jnp.maximum(pred, 0) - pred * truth.data +
|
59 |
+
jnp.log1p(jnp.exp(-jnp.abs(pred))))
|
60 |
+
mask = (truth.data != _OutputClass.MASKED)
|
61 |
+
|
62 |
+
elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
|
63 |
+
mask = jnp.any(truth.data == _OutputClass.POSITIVE, axis=-1)
|
64 |
+
masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
|
65 |
+
jnp.float32)
|
66 |
+
loss = -jnp.sum(masked_truth * jax.nn.log_softmax(pred), axis=-1)
|
67 |
+
|
68 |
+
elif truth.type_ == _Type.POINTER:
|
69 |
+
loss = -jnp.sum(
|
70 |
+
hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), axis=-1)
|
71 |
+
|
72 |
+
elif truth.type_ == _Type.PERMUTATION_POINTER:
|
73 |
+
# Predictions are NxN logits aiming to represent a doubly stochastic matrix.
|
74 |
+
# Compute the cross entropy between doubly stochastic pred and truth_data
|
75 |
+
loss = -jnp.sum(truth.data * pred, axis=-1)
|
76 |
+
|
77 |
+
if mask is not None:
|
78 |
+
mask = mask * _expand_and_broadcast_to(is_last, loss)
|
79 |
+
else:
|
80 |
+
mask = _expand_and_broadcast_to(is_last, loss)
|
81 |
+
total_mask = jnp.maximum(jnp.sum(mask), EPS)
|
82 |
+
return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask
|
83 |
+
|
84 |
+
|
85 |
+
def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float:
|
86 |
+
"""Output loss for full-sample training."""
|
87 |
+
|
88 |
+
if truth.type_ == _Type.SCALAR:
|
89 |
+
total_loss = jnp.mean((pred - truth.data)**2)
|
90 |
+
|
91 |
+
elif truth.type_ == _Type.MASK:
|
92 |
+
loss = (
|
93 |
+
jnp.maximum(pred, 0) - pred * truth.data +
|
94 |
+
jnp.log1p(jnp.exp(-jnp.abs(pred))))
|
95 |
+
mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32)
|
96 |
+
total_loss = jnp.sum(loss * mask) / jnp.sum(mask)
|
97 |
+
|
98 |
+
elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]:
|
99 |
+
masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype(
|
100 |
+
jnp.float32)
|
101 |
+
total_loss = (-jnp.sum(masked_truth * jax.nn.log_softmax(pred)) /
|
102 |
+
jnp.sum(truth.data == _OutputClass.POSITIVE))
|
103 |
+
|
104 |
+
elif truth.type_ == _Type.POINTER:
|
105 |
+
total_loss = (
|
106 |
+
jnp.mean(-jnp.sum(
|
107 |
+
hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred),
|
108 |
+
axis=-1)))
|
109 |
+
|
110 |
+
elif truth.type_ == _Type.PERMUTATION_POINTER:
|
111 |
+
# Predictions are NxN logits aiming to represent a doubly stochastic matrix.
|
112 |
+
# Compute the cross entropy between doubly stochastic pred and truth_data
|
113 |
+
total_loss = jnp.mean(-jnp.sum(truth.data * pred, axis=-1))
|
114 |
+
|
115 |
+
return total_loss
|
116 |
+
|
117 |
+
|
118 |
+
def hint_loss_chunked(
|
119 |
+
truth: _DataPoint,
|
120 |
+
pred: _Array,
|
121 |
+
is_first: _Array,
|
122 |
+
nb_nodes: int,
|
123 |
+
):
|
124 |
+
"""Hint loss for time-chunked training."""
|
125 |
+
loss, mask = _hint_loss(
|
126 |
+
truth_data=truth.data,
|
127 |
+
truth_type=truth.type_,
|
128 |
+
pred=pred,
|
129 |
+
nb_nodes=nb_nodes,
|
130 |
+
)
|
131 |
+
|
132 |
+
mask *= (1 - _expand_to(is_first, loss)).astype(jnp.float32)
|
133 |
+
loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS)
|
134 |
+
return loss
|
135 |
+
|
136 |
+
|
137 |
+
def hint_loss(
|
138 |
+
truth: _DataPoint,
|
139 |
+
preds: List[_Array],
|
140 |
+
lengths: _Array,
|
141 |
+
nb_nodes: int,
|
142 |
+
verbose: bool = False,
|
143 |
+
):
|
144 |
+
"""Hint loss for full-sample training."""
|
145 |
+
total_loss = 0.
|
146 |
+
verbose_loss = {}
|
147 |
+
length = truth.data.shape[0] - 1
|
148 |
+
|
149 |
+
loss, mask = _hint_loss(
|
150 |
+
truth_data=truth.data[1:],
|
151 |
+
truth_type=truth.type_,
|
152 |
+
pred=jnp.stack(preds),
|
153 |
+
nb_nodes=nb_nodes,
|
154 |
+
)
|
155 |
+
mask *= _is_not_done_broadcast(lengths, jnp.arange(length)[:, None], loss)
|
156 |
+
loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS)
|
157 |
+
if verbose:
|
158 |
+
verbose_loss['loss_' + truth.name] = loss
|
159 |
+
else:
|
160 |
+
total_loss += loss
|
161 |
+
|
162 |
+
return verbose_loss if verbose else total_loss
|
163 |
+
|
164 |
+
|
165 |
+
def _hint_loss(
|
166 |
+
truth_data: _Array,
|
167 |
+
truth_type: str,
|
168 |
+
pred: _Array,
|
169 |
+
nb_nodes: int,
|
170 |
+
) -> Tuple[_Array, _Array]:
|
171 |
+
"""Hint loss helper."""
|
172 |
+
mask = None
|
173 |
+
if truth_type == _Type.SCALAR:
|
174 |
+
loss = (pred - truth_data)**2
|
175 |
+
|
176 |
+
elif truth_type == _Type.MASK:
|
177 |
+
loss = (jnp.maximum(pred, 0) - pred * truth_data +
|
178 |
+
jnp.log1p(jnp.exp(-jnp.abs(pred))))
|
179 |
+
mask = (truth_data != _OutputClass.MASKED).astype(jnp.float32) # pytype: disable=attribute-error # numpy-scalars
|
180 |
+
|
181 |
+
elif truth_type == _Type.MASK_ONE:
|
182 |
+
loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1,
|
183 |
+
keepdims=True)
|
184 |
+
|
185 |
+
elif truth_type == _Type.CATEGORICAL:
|
186 |
+
loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1)
|
187 |
+
mask = jnp.any(truth_data == _OutputClass.POSITIVE, axis=-1).astype(
|
188 |
+
jnp.float32)
|
189 |
+
|
190 |
+
elif truth_type == _Type.POINTER:
|
191 |
+
loss = -jnp.sum(
|
192 |
+
hk.one_hot(truth_data, nb_nodes) * jax.nn.log_softmax(pred),
|
193 |
+
axis=-1)
|
194 |
+
|
195 |
+
elif truth_type == _Type.PERMUTATION_POINTER:
|
196 |
+
# Predictions are NxN logits aiming to represent a doubly stochastic matrix.
|
197 |
+
# Compute the cross entropy between doubly stochastic pred and truth_data
|
198 |
+
loss = -jnp.sum(truth_data * pred, axis=-1)
|
199 |
+
|
200 |
+
if mask is None:
|
201 |
+
mask = jnp.ones_like(loss)
|
202 |
+
return loss, mask
|
203 |
+
|
204 |
+
|
205 |
+
def _is_not_done_broadcast(lengths, i, tensor):
|
206 |
+
is_not_done = (lengths > i + 1) * 1.0
|
207 |
+
while len(is_not_done.shape) < len(tensor.shape): # pytype: disable=attribute-error # numpy-scalars
|
208 |
+
is_not_done = jnp.expand_dims(is_not_done, -1)
|
209 |
+
return is_not_done
|
benchmarks/CLRS/env/losses_test.py
ADDED
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Unit tests for `losses.py`."""
|
17 |
+
|
18 |
+
from typing import Generator
|
19 |
+
|
20 |
+
from absl.testing import absltest
|
21 |
+
from absl.testing import parameterized
|
22 |
+
|
23 |
+
from clrs._src import dataset
|
24 |
+
from clrs._src import losses
|
25 |
+
from clrs._src import probing
|
26 |
+
from clrs._src import samplers
|
27 |
+
from clrs._src import specs
|
28 |
+
import jax
|
29 |
+
import jax.numpy as jnp
|
30 |
+
import numpy as np
|
31 |
+
|
32 |
+
_Array = np.ndarray
|
33 |
+
_Location = specs.Location
|
34 |
+
|
35 |
+
|
36 |
+
def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler:
|
37 |
+
sampler, _ = samplers.build_sampler(
|
38 |
+
algo,
|
39 |
+
seed=samplers.CLRS30['val']['seed'],
|
40 |
+
num_samples=samplers.CLRS30['val']['num_samples'],
|
41 |
+
length=nb_nodes,
|
42 |
+
)
|
43 |
+
return sampler
|
44 |
+
|
45 |
+
|
46 |
+
def _make_iterable_sampler(
|
47 |
+
algo: str, batch_size: int,
|
48 |
+
nb_nodes: int) -> Generator[samplers.Feedback, None, None]:
|
49 |
+
sampler = _make_sampler(algo, nb_nodes)
|
50 |
+
while True:
|
51 |
+
yield sampler.next(batch_size)
|
52 |
+
|
53 |
+
|
54 |
+
def _as_pred_data(x, nb_nodes, seed, batch_axis):
|
55 |
+
"""Fake a prediction from a data point."""
|
56 |
+
# Permute along batch axis to make the prediction different.
|
57 |
+
key = jax.random.PRNGKey(seed)
|
58 |
+
data = jax.random.permutation(key, x.data, axis=batch_axis)
|
59 |
+
# Extend to one-hot for pointer types.
|
60 |
+
if x.type_ == specs.Type.POINTER:
|
61 |
+
return jax.nn.one_hot(data, nb_nodes)
|
62 |
+
return data
|
63 |
+
|
64 |
+
|
65 |
+
def _mask_datapoint(x, seed, t_axis=None):
|
66 |
+
"""Add some masking to data."""
|
67 |
+
key = jax.random.PRNGKey(seed)
|
68 |
+
data = x.data
|
69 |
+
if x.type_ == specs.Type.MASK:
|
70 |
+
# mask some data at random
|
71 |
+
mask_shape = list(data.shape)
|
72 |
+
if t_axis is not None:
|
73 |
+
mask_shape[t_axis] = 1
|
74 |
+
mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
|
75 |
+
data = jnp.where(mask, specs.OutputClass.MASKED, data)
|
76 |
+
elif x.type_ in [specs.Type.CATEGORICAL, specs.Type.MASK_ONE]:
|
77 |
+
# mask some data at random (all categories together)
|
78 |
+
mask_shape = list(data.shape)[:-1]
|
79 |
+
if t_axis is not None:
|
80 |
+
mask_shape[t_axis] = 1
|
81 |
+
mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2
|
82 |
+
data = jnp.where(mask[..., None], specs.OutputClass.MASKED, data)
|
83 |
+
return probing.DataPoint(name=x.name, location=x.location, type_=x.type_,
|
84 |
+
data=data)
|
85 |
+
|
86 |
+
|
87 |
+
def _rand_diff(seed, shape):
|
88 |
+
return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0
|
89 |
+
|
90 |
+
|
91 |
+
def _rand_mask(seed, shape, p=0.5):
|
92 |
+
return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float)
|
93 |
+
|
94 |
+
|
95 |
+
def invert(d):
|
96 |
+
"""Dict of lists -> list of dicts."""
|
97 |
+
if d:
|
98 |
+
return [dict(zip(d, i)) for i in zip(*d.values())]
|
99 |
+
|
100 |
+
|
101 |
+
def _create_data(algo, nb_nodes):
|
102 |
+
batch_size = 8
|
103 |
+
|
104 |
+
ds = _make_iterable_sampler(algo, batch_size, nb_nodes)
|
105 |
+
full_sample = next(ds)
|
106 |
+
|
107 |
+
chunk_length = full_sample.features.lengths[0].astype(int)
|
108 |
+
chunked_ds = dataset.chunkify(
|
109 |
+
_make_iterable_sampler(algo, batch_size, nb_nodes),
|
110 |
+
chunk_length)
|
111 |
+
chunk_sample = next(chunked_ds)
|
112 |
+
return full_sample, chunk_sample
|
113 |
+
|
114 |
+
|
115 |
+
class FullVsChunkLossesTest(parameterized.TestCase):
|
116 |
+
"""Test that the full and chunked versions of the losses match."""
|
117 |
+
|
118 |
+
# Test two algorithms with fixed-length, covering all data types
|
119 |
+
@parameterized.parameters('dfs', 'floyd_warshall')
|
120 |
+
def test_output_loss(self, algo):
|
121 |
+
nb_nodes = 16
|
122 |
+
full_sample, chunk_sample = _create_data(algo, nb_nodes)
|
123 |
+
|
124 |
+
# Calculate output loss.
|
125 |
+
for truth_full, truth_chunked in zip(full_sample.outputs,
|
126 |
+
chunk_sample.outputs):
|
127 |
+
chunk_output_loss = losses.output_loss_chunked(
|
128 |
+
truth=_mask_datapoint(truth_chunked, seed=0),
|
129 |
+
pred=_as_pred_data(truth_chunked, nb_nodes, 0, 1),
|
130 |
+
is_last=chunk_sample.features.is_last,
|
131 |
+
nb_nodes=nb_nodes,
|
132 |
+
)
|
133 |
+
full_output_loss = losses.output_loss(
|
134 |
+
truth=_mask_datapoint(truth_full, seed=0),
|
135 |
+
pred=_as_pred_data(truth_full, nb_nodes, 0, 0),
|
136 |
+
nb_nodes=nb_nodes,
|
137 |
+
)
|
138 |
+
np.testing.assert_allclose(chunk_output_loss, full_output_loss, rtol=1e-4)
|
139 |
+
|
140 |
+
@parameterized.parameters('dfs', 'floyd_warshall')
|
141 |
+
def test_hint_loss(self, algo):
|
142 |
+
nb_nodes = 16
|
143 |
+
full_sample, chunk_sample = _create_data(algo, nb_nodes)
|
144 |
+
for truth_full, truth_chunked in zip(full_sample.features.hints,
|
145 |
+
chunk_sample.features.hints):
|
146 |
+
np.testing.assert_array_equal(truth_full.data, truth_chunked.data)
|
147 |
+
pred = _as_pred_data(truth_chunked, nb_nodes, 0, 1)
|
148 |
+
chunk_hint_loss = losses.hint_loss_chunked(
|
149 |
+
truth=_mask_datapoint(truth_chunked, seed=1, t_axis=0),
|
150 |
+
pred=pred,
|
151 |
+
is_first=chunk_sample.features.is_first,
|
152 |
+
nb_nodes=nb_nodes,
|
153 |
+
)
|
154 |
+
|
155 |
+
full_preds = pred[1:]
|
156 |
+
full_hint_loss = losses.hint_loss(
|
157 |
+
truth=_mask_datapoint(truth_full, 1, t_axis=0),
|
158 |
+
preds=full_preds,
|
159 |
+
lengths=full_sample.features.lengths,
|
160 |
+
nb_nodes=nb_nodes,
|
161 |
+
)
|
162 |
+
np.testing.assert_allclose(chunk_hint_loss, full_hint_loss, rtol=1e-4)
|
163 |
+
|
164 |
+
|
165 |
+
if __name__ == '__main__':
|
166 |
+
absltest.main()
|
benchmarks/CLRS/env/model.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Model base classes and utilities."""
|
17 |
+
|
18 |
+
import abc
|
19 |
+
from typing import Dict, List, Optional, Union
|
20 |
+
|
21 |
+
from clrs._src import probing
|
22 |
+
from clrs._src import samplers
|
23 |
+
from clrs._src import specs
|
24 |
+
|
25 |
+
|
26 |
+
Result = Dict[str, probing.DataPoint]
|
27 |
+
|
28 |
+
|
29 |
+
class Model(abc.ABC):
|
30 |
+
"""Abstract base class for CLRS3-B models."""
|
31 |
+
|
32 |
+
def __init__(self, spec: Union[specs.Spec, List[specs.Spec]]):
|
33 |
+
"""Set up the problem, prepare to predict on first task."""
|
34 |
+
if not isinstance(spec, list):
|
35 |
+
spec = [spec]
|
36 |
+
self._spec = spec
|
37 |
+
|
38 |
+
@abc.abstractmethod
|
39 |
+
def predict(self, features: samplers.Features) -> Result:
|
40 |
+
"""Make predictions about the current task."""
|
41 |
+
pass
|
42 |
+
|
43 |
+
@abc.abstractmethod
|
44 |
+
def feedback(self, feedback: Optional[samplers.Feedback]):
|
45 |
+
"""Advance to the next task, incorporating any available feedback."""
|
46 |
+
pass
|
benchmarks/CLRS/env/nets.py
ADDED
@@ -0,0 +1,719 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""JAX implementation of CLRS basic network."""
|
17 |
+
|
18 |
+
import functools
|
19 |
+
|
20 |
+
from typing import Dict, List, Optional, Tuple
|
21 |
+
|
22 |
+
import chex
|
23 |
+
|
24 |
+
from clrs._src import decoders
|
25 |
+
from clrs._src import encoders
|
26 |
+
from clrs._src import probing
|
27 |
+
from clrs._src import processors
|
28 |
+
from clrs._src import samplers
|
29 |
+
from clrs._src import specs
|
30 |
+
|
31 |
+
import haiku as hk
|
32 |
+
import jax
|
33 |
+
import jax.numpy as jnp
|
34 |
+
|
35 |
+
|
36 |
+
_Array = chex.Array
|
37 |
+
_DataPoint = probing.DataPoint
|
38 |
+
_Features = samplers.Features
|
39 |
+
_FeaturesChunked = samplers.FeaturesChunked
|
40 |
+
_Location = specs.Location
|
41 |
+
_Spec = specs.Spec
|
42 |
+
_Stage = specs.Stage
|
43 |
+
_Trajectory = samplers.Trajectory
|
44 |
+
_Type = specs.Type
|
45 |
+
|
46 |
+
|
47 |
+
@chex.dataclass
|
48 |
+
class _MessagePassingScanState:
|
49 |
+
hint_preds: chex.Array
|
50 |
+
output_preds: chex.Array
|
51 |
+
hiddens: chex.Array
|
52 |
+
lstm_state: Optional[hk.LSTMState]
|
53 |
+
|
54 |
+
|
55 |
+
@chex.dataclass
|
56 |
+
class _MessagePassingOutputChunked:
|
57 |
+
hint_preds: chex.Array
|
58 |
+
output_preds: chex.Array
|
59 |
+
|
60 |
+
|
61 |
+
@chex.dataclass
|
62 |
+
class MessagePassingStateChunked:
|
63 |
+
inputs: chex.Array
|
64 |
+
hints: chex.Array
|
65 |
+
is_first: chex.Array
|
66 |
+
hint_preds: chex.Array
|
67 |
+
hiddens: chex.Array
|
68 |
+
lstm_state: Optional[hk.LSTMState]
|
69 |
+
|
70 |
+
|
71 |
+
class Net(hk.Module):
|
72 |
+
"""Building blocks (networks) used to encode and decode messages."""
|
73 |
+
|
74 |
+
def __init__(
|
75 |
+
self,
|
76 |
+
spec: List[_Spec],
|
77 |
+
hidden_dim: int,
|
78 |
+
encode_hints: bool,
|
79 |
+
decode_hints: bool,
|
80 |
+
processor_factory: processors.ProcessorFactory,
|
81 |
+
use_lstm: bool,
|
82 |
+
encoder_init: str,
|
83 |
+
dropout_prob: float,
|
84 |
+
hint_teacher_forcing: float,
|
85 |
+
hint_repred_mode='soft',
|
86 |
+
nb_dims=None,
|
87 |
+
nb_msg_passing_steps=1,
|
88 |
+
name: str = 'net',
|
89 |
+
):
|
90 |
+
"""Constructs a `Net`."""
|
91 |
+
super().__init__(name=name)
|
92 |
+
|
93 |
+
self._dropout_prob = dropout_prob
|
94 |
+
self._hint_teacher_forcing = hint_teacher_forcing
|
95 |
+
self._hint_repred_mode = hint_repred_mode
|
96 |
+
self.spec = spec
|
97 |
+
self.hidden_dim = hidden_dim
|
98 |
+
self.encode_hints = encode_hints
|
99 |
+
self.decode_hints = decode_hints
|
100 |
+
self.processor_factory = processor_factory
|
101 |
+
self.nb_dims = nb_dims
|
102 |
+
self.use_lstm = use_lstm
|
103 |
+
self.encoder_init = encoder_init
|
104 |
+
self.nb_msg_passing_steps = nb_msg_passing_steps
|
105 |
+
|
106 |
+
def _msg_passing_step(self,
|
107 |
+
mp_state: _MessagePassingScanState,
|
108 |
+
i: int,
|
109 |
+
hints: List[_DataPoint],
|
110 |
+
repred: bool,
|
111 |
+
lengths: chex.Array,
|
112 |
+
batch_size: int,
|
113 |
+
nb_nodes: int,
|
114 |
+
inputs: _Trajectory,
|
115 |
+
first_step: bool,
|
116 |
+
spec: _Spec,
|
117 |
+
encs: Dict[str, List[hk.Module]],
|
118 |
+
decs: Dict[str, Tuple[hk.Module]],
|
119 |
+
return_hints: bool,
|
120 |
+
return_all_outputs: bool
|
121 |
+
):
|
122 |
+
if self.decode_hints and not first_step:
|
123 |
+
assert self._hint_repred_mode in ['soft', 'hard', 'hard_on_eval']
|
124 |
+
hard_postprocess = (self._hint_repred_mode == 'hard' or
|
125 |
+
(self._hint_repred_mode == 'hard_on_eval' and repred))
|
126 |
+
decoded_hint = decoders.postprocess(spec,
|
127 |
+
mp_state.hint_preds,
|
128 |
+
sinkhorn_temperature=0.1,
|
129 |
+
sinkhorn_steps=25,
|
130 |
+
hard=hard_postprocess)
|
131 |
+
if repred and self.decode_hints and not first_step:
|
132 |
+
cur_hint = []
|
133 |
+
for hint in decoded_hint:
|
134 |
+
cur_hint.append(decoded_hint[hint])
|
135 |
+
else:
|
136 |
+
cur_hint = []
|
137 |
+
needs_noise = (self.decode_hints and not first_step and
|
138 |
+
self._hint_teacher_forcing < 1.0)
|
139 |
+
if needs_noise:
|
140 |
+
# For noisy teacher forcing, choose which examples in the batch to force
|
141 |
+
force_mask = jax.random.bernoulli(
|
142 |
+
hk.next_rng_key(), self._hint_teacher_forcing,
|
143 |
+
(batch_size,))
|
144 |
+
else:
|
145 |
+
force_mask = None
|
146 |
+
for hint in hints:
|
147 |
+
hint_data = jnp.asarray(hint.data)[i]
|
148 |
+
_, loc, typ = spec[hint.name]
|
149 |
+
if needs_noise:
|
150 |
+
if (typ == _Type.POINTER and
|
151 |
+
decoded_hint[hint.name].type_ == _Type.SOFT_POINTER):
|
152 |
+
# When using soft pointers, the decoded hints cannot be summarised
|
153 |
+
# as indices (as would happen in hard postprocessing), so we need
|
154 |
+
# to raise the ground-truth hint (potentially used for teacher
|
155 |
+
# forcing) to its one-hot version.
|
156 |
+
hint_data = hk.one_hot(hint_data, nb_nodes)
|
157 |
+
typ = _Type.SOFT_POINTER
|
158 |
+
hint_data = jnp.where(_expand_to(force_mask, hint_data),
|
159 |
+
hint_data,
|
160 |
+
decoded_hint[hint.name].data)
|
161 |
+
cur_hint.append(
|
162 |
+
probing.DataPoint(
|
163 |
+
name=hint.name, location=loc, type_=typ, data=hint_data))
|
164 |
+
|
165 |
+
hiddens, output_preds_cand, hint_preds, lstm_state = self._one_step_pred(
|
166 |
+
inputs, cur_hint, mp_state.hiddens,
|
167 |
+
batch_size, nb_nodes, mp_state.lstm_state,
|
168 |
+
spec, encs, decs, repred)
|
169 |
+
|
170 |
+
if first_step:
|
171 |
+
output_preds = output_preds_cand
|
172 |
+
else:
|
173 |
+
output_preds = {}
|
174 |
+
for outp in mp_state.output_preds:
|
175 |
+
is_not_done = _is_not_done_broadcast(lengths, i,
|
176 |
+
output_preds_cand[outp])
|
177 |
+
output_preds[outp] = is_not_done * output_preds_cand[outp] + (
|
178 |
+
1.0 - is_not_done) * mp_state.output_preds[outp]
|
179 |
+
|
180 |
+
new_mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars
|
181 |
+
hint_preds=hint_preds,
|
182 |
+
output_preds=output_preds,
|
183 |
+
hiddens=hiddens,
|
184 |
+
lstm_state=lstm_state)
|
185 |
+
# Save memory by not stacking unnecessary fields
|
186 |
+
accum_mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars
|
187 |
+
hint_preds=hint_preds if return_hints else None,
|
188 |
+
output_preds=output_preds if return_all_outputs else None,
|
189 |
+
hiddens=None, lstm_state=None)
|
190 |
+
|
191 |
+
# Complying to jax.scan, the first returned value is the state we carry over
|
192 |
+
# the second value is the output that will be stacked over steps.
|
193 |
+
return new_mp_state, accum_mp_state
|
194 |
+
|
195 |
+
def __call__(self, features_list: List[_Features], repred: bool,
|
196 |
+
algorithm_index: int,
|
197 |
+
return_hints: bool,
|
198 |
+
return_all_outputs: bool):
|
199 |
+
"""Process one batch of data.
|
200 |
+
|
201 |
+
Args:
|
202 |
+
features_list: A list of _Features objects, each with the inputs, hints
|
203 |
+
and lengths for a batch o data corresponding to one algorithm.
|
204 |
+
The list should have either length 1, at train/evaluation time,
|
205 |
+
or length equal to the number of algorithms this Net is meant to
|
206 |
+
process, at initialization.
|
207 |
+
repred: False during training, when we have access to ground-truth hints.
|
208 |
+
True in validation/test mode, when we have to use our own
|
209 |
+
hint predictions.
|
210 |
+
algorithm_index: Which algorithm is being processed. It can be -1 at
|
211 |
+
initialisation (either because we are initialising the parameters of
|
212 |
+
the module or because we are intialising the message-passing state),
|
213 |
+
meaning that all algorithms should be processed, in which case
|
214 |
+
`features_list` should have length equal to the number of specs of
|
215 |
+
the Net. Otherwise, `algorithm_index` should be
|
216 |
+
between 0 and `length(self.spec) - 1`, meaning only one of the
|
217 |
+
algorithms will be processed, and `features_list` should have length 1.
|
218 |
+
return_hints: Whether to accumulate and return the predicted hints,
|
219 |
+
when they are decoded.
|
220 |
+
return_all_outputs: Whether to return the full sequence of outputs, or
|
221 |
+
just the last step's output.
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
A 2-tuple with (output predictions, hint predictions)
|
225 |
+
for the selected algorithm.
|
226 |
+
"""
|
227 |
+
if algorithm_index == -1:
|
228 |
+
algorithm_indices = range(len(features_list))
|
229 |
+
else:
|
230 |
+
algorithm_indices = [algorithm_index]
|
231 |
+
assert len(algorithm_indices) == len(features_list)
|
232 |
+
|
233 |
+
self.encoders, self.decoders = self._construct_encoders_decoders()
|
234 |
+
self.processor = self.processor_factory(self.hidden_dim)
|
235 |
+
|
236 |
+
# Optionally construct LSTM.
|
237 |
+
if self.use_lstm:
|
238 |
+
self.lstm = hk.LSTM(
|
239 |
+
hidden_size=self.hidden_dim,
|
240 |
+
name='processor_lstm')
|
241 |
+
lstm_init = self.lstm.initial_state
|
242 |
+
else:
|
243 |
+
self.lstm = None
|
244 |
+
lstm_init = lambda x: 0
|
245 |
+
|
246 |
+
for algorithm_index, features in zip(algorithm_indices, features_list):
|
247 |
+
inputs = features.inputs
|
248 |
+
hints = features.hints
|
249 |
+
lengths = features.lengths
|
250 |
+
|
251 |
+
batch_size, nb_nodes = _data_dimensions(features)
|
252 |
+
|
253 |
+
nb_mp_steps = max(1, hints[0].data.shape[0] - 1)
|
254 |
+
hiddens = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))
|
255 |
+
|
256 |
+
if self.use_lstm:
|
257 |
+
lstm_state = lstm_init(batch_size * nb_nodes)
|
258 |
+
lstm_state = jax.tree_util.tree_map(
|
259 |
+
lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]),
|
260 |
+
lstm_state)
|
261 |
+
else:
|
262 |
+
lstm_state = None
|
263 |
+
|
264 |
+
mp_state = _MessagePassingScanState( # pytype: disable=wrong-arg-types # numpy-scalars
|
265 |
+
hint_preds=None, output_preds=None,
|
266 |
+
hiddens=hiddens, lstm_state=lstm_state)
|
267 |
+
|
268 |
+
# Do the first step outside of the scan because it has a different
|
269 |
+
# computation graph.
|
270 |
+
common_args = dict(
|
271 |
+
hints=hints,
|
272 |
+
repred=repred,
|
273 |
+
inputs=inputs,
|
274 |
+
batch_size=batch_size,
|
275 |
+
nb_nodes=nb_nodes,
|
276 |
+
lengths=lengths,
|
277 |
+
spec=self.spec[algorithm_index],
|
278 |
+
encs=self.encoders[algorithm_index],
|
279 |
+
decs=self.decoders[algorithm_index],
|
280 |
+
return_hints=return_hints,
|
281 |
+
return_all_outputs=return_all_outputs,
|
282 |
+
)
|
283 |
+
mp_state, lean_mp_state = self._msg_passing_step(
|
284 |
+
mp_state,
|
285 |
+
i=0,
|
286 |
+
first_step=True,
|
287 |
+
**common_args)
|
288 |
+
|
289 |
+
# Then scan through the rest.
|
290 |
+
scan_fn = functools.partial(
|
291 |
+
self._msg_passing_step,
|
292 |
+
first_step=False,
|
293 |
+
**common_args)
|
294 |
+
|
295 |
+
output_mp_state, accum_mp_state = hk.scan(
|
296 |
+
scan_fn,
|
297 |
+
mp_state,
|
298 |
+
jnp.arange(nb_mp_steps - 1) + 1,
|
299 |
+
length=nb_mp_steps - 1)
|
300 |
+
|
301 |
+
# We only return the last algorithm's output. That's because
|
302 |
+
# the output only matters when a single algorithm is processed; the case
|
303 |
+
# `algorithm_index==-1` (meaning all algorithms should be processed)
|
304 |
+
# is used only to init parameters.
|
305 |
+
accum_mp_state = jax.tree_util.tree_map(
|
306 |
+
lambda init, tail: jnp.concatenate([init[None], tail], axis=0),
|
307 |
+
lean_mp_state, accum_mp_state)
|
308 |
+
|
309 |
+
def invert(d):
|
310 |
+
"""Dict of lists -> list of dicts."""
|
311 |
+
if d:
|
312 |
+
return [dict(zip(d, i)) for i in zip(*d.values())]
|
313 |
+
|
314 |
+
if return_all_outputs:
|
315 |
+
output_preds = {k: jnp.stack(v)
|
316 |
+
for k, v in accum_mp_state.output_preds.items()}
|
317 |
+
else:
|
318 |
+
output_preds = output_mp_state.output_preds
|
319 |
+
hint_preds = invert(accum_mp_state.hint_preds)
|
320 |
+
|
321 |
+
return output_preds, hint_preds
|
322 |
+
|
323 |
+
def _construct_encoders_decoders(self):
|
324 |
+
"""Constructs encoders and decoders, separate for each algorithm."""
|
325 |
+
encoders_ = []
|
326 |
+
decoders_ = []
|
327 |
+
enc_algo_idx = None
|
328 |
+
for (algo_idx, spec) in enumerate(self.spec):
|
329 |
+
enc = {}
|
330 |
+
dec = {}
|
331 |
+
for name, (stage, loc, t) in spec.items():
|
332 |
+
if stage == _Stage.INPUT or (
|
333 |
+
stage == _Stage.HINT and self.encode_hints):
|
334 |
+
# Build input encoders.
|
335 |
+
if name == specs.ALGO_IDX_INPUT_NAME:
|
336 |
+
if enc_algo_idx is None:
|
337 |
+
enc_algo_idx = [hk.Linear(self.hidden_dim,
|
338 |
+
name=f'{name}_enc_linear')]
|
339 |
+
enc[name] = enc_algo_idx
|
340 |
+
else:
|
341 |
+
enc[name] = encoders.construct_encoders(
|
342 |
+
stage, loc, t, hidden_dim=self.hidden_dim,
|
343 |
+
init=self.encoder_init,
|
344 |
+
name=f'algo_{algo_idx}_{name}')
|
345 |
+
|
346 |
+
if stage == _Stage.OUTPUT or (
|
347 |
+
stage == _Stage.HINT and self.decode_hints):
|
348 |
+
# Build output decoders.
|
349 |
+
dec[name] = decoders.construct_decoders(
|
350 |
+
loc, t, hidden_dim=self.hidden_dim,
|
351 |
+
nb_dims=self.nb_dims[algo_idx][name],
|
352 |
+
name=f'algo_{algo_idx}_{name}')
|
353 |
+
encoders_.append(enc)
|
354 |
+
decoders_.append(dec)
|
355 |
+
|
356 |
+
return encoders_, decoders_
|
357 |
+
|
358 |
+
def _one_step_pred(
|
359 |
+
self,
|
360 |
+
inputs: _Trajectory,
|
361 |
+
hints: _Trajectory,
|
362 |
+
hidden: _Array,
|
363 |
+
batch_size: int,
|
364 |
+
nb_nodes: int,
|
365 |
+
lstm_state: Optional[hk.LSTMState],
|
366 |
+
spec: _Spec,
|
367 |
+
encs: Dict[str, List[hk.Module]],
|
368 |
+
decs: Dict[str, Tuple[hk.Module]],
|
369 |
+
repred: bool,
|
370 |
+
):
|
371 |
+
"""Generates one-step predictions."""
|
372 |
+
|
373 |
+
# Initialise empty node/edge/graph features and adjacency matrix.
|
374 |
+
node_fts = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))
|
375 |
+
edge_fts = jnp.zeros((batch_size, nb_nodes, nb_nodes, self.hidden_dim))
|
376 |
+
graph_fts = jnp.zeros((batch_size, self.hidden_dim))
|
377 |
+
adj_mat = jnp.repeat(
|
378 |
+
jnp.expand_dims(jnp.eye(nb_nodes), 0), batch_size, axis=0)
|
379 |
+
|
380 |
+
# ENCODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
381 |
+
# Encode node/edge/graph features from inputs and (optionally) hints.
|
382 |
+
trajectories = [inputs]
|
383 |
+
if self.encode_hints:
|
384 |
+
trajectories.append(hints)
|
385 |
+
|
386 |
+
for trajectory in trajectories:
|
387 |
+
for dp in trajectory:
|
388 |
+
try:
|
389 |
+
dp = encoders.preprocess(dp, nb_nodes)
|
390 |
+
assert dp.type_ != _Type.SOFT_POINTER
|
391 |
+
adj_mat = encoders.accum_adj_mat(dp, adj_mat)
|
392 |
+
encoder = encs[dp.name]
|
393 |
+
edge_fts = encoders.accum_edge_fts(encoder, dp, edge_fts)
|
394 |
+
node_fts = encoders.accum_node_fts(encoder, dp, node_fts)
|
395 |
+
graph_fts = encoders.accum_graph_fts(encoder, dp, graph_fts)
|
396 |
+
except Exception as e:
|
397 |
+
raise Exception(f'Failed to process {dp}') from e
|
398 |
+
|
399 |
+
# PROCESS ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
400 |
+
nxt_hidden = hidden
|
401 |
+
for _ in range(self.nb_msg_passing_steps):
|
402 |
+
nxt_hidden, nxt_edge = self.processor(
|
403 |
+
node_fts,
|
404 |
+
edge_fts,
|
405 |
+
graph_fts,
|
406 |
+
adj_mat,
|
407 |
+
nxt_hidden,
|
408 |
+
batch_size=batch_size,
|
409 |
+
nb_nodes=nb_nodes,
|
410 |
+
)
|
411 |
+
|
412 |
+
if not repred: # dropout only on training
|
413 |
+
nxt_hidden = hk.dropout(hk.next_rng_key(), self._dropout_prob, nxt_hidden)
|
414 |
+
|
415 |
+
if self.use_lstm:
|
416 |
+
# lstm doesn't accept multiple batch dimensions (in our case, batch and
|
417 |
+
# nodes), so we vmap over the (first) batch dimension.
|
418 |
+
nxt_hidden, nxt_lstm_state = jax.vmap(self.lstm)(nxt_hidden, lstm_state)
|
419 |
+
else:
|
420 |
+
nxt_lstm_state = None
|
421 |
+
|
422 |
+
h_t = jnp.concatenate([node_fts, hidden, nxt_hidden], axis=-1)
|
423 |
+
if nxt_edge is not None:
|
424 |
+
e_t = jnp.concatenate([edge_fts, nxt_edge], axis=-1)
|
425 |
+
else:
|
426 |
+
e_t = edge_fts
|
427 |
+
|
428 |
+
# DECODE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
429 |
+
# Decode features and (optionally) hints.
|
430 |
+
hint_preds, output_preds = decoders.decode_fts(
|
431 |
+
decoders=decs,
|
432 |
+
spec=spec,
|
433 |
+
h_t=h_t,
|
434 |
+
adj_mat=adj_mat,
|
435 |
+
edge_fts=e_t,
|
436 |
+
graph_fts=graph_fts,
|
437 |
+
inf_bias=self.processor.inf_bias,
|
438 |
+
inf_bias_edge=self.processor.inf_bias_edge,
|
439 |
+
repred=repred,
|
440 |
+
)
|
441 |
+
|
442 |
+
return nxt_hidden, output_preds, hint_preds, nxt_lstm_state
|
443 |
+
|
444 |
+
|
445 |
+
class NetChunked(Net):
|
446 |
+
"""A Net that will process time-chunked data instead of full samples."""
|
447 |
+
|
448 |
+
def _msg_passing_step(self,
|
449 |
+
mp_state: MessagePassingStateChunked,
|
450 |
+
xs,
|
451 |
+
repred: bool,
|
452 |
+
init_mp_state: bool,
|
453 |
+
batch_size: int,
|
454 |
+
nb_nodes: int,
|
455 |
+
spec: _Spec,
|
456 |
+
encs: Dict[str, List[hk.Module]],
|
457 |
+
decs: Dict[str, Tuple[hk.Module]],
|
458 |
+
):
|
459 |
+
"""Perform one message passing step.
|
460 |
+
|
461 |
+
This function is unrolled along the time axis to process a data chunk.
|
462 |
+
|
463 |
+
Args:
|
464 |
+
mp_state: message-passing state. Includes the inputs, hints,
|
465 |
+
beginning-of-sample markers, hint predictions, hidden and lstm state
|
466 |
+
to be used for prediction in the current step.
|
467 |
+
xs: A 3-tuple of with the next timestep's inputs, hints, and
|
468 |
+
beginning-of-sample markers. These will replace the contents of
|
469 |
+
the `mp_state` at the output, in readiness for the next unroll step of
|
470 |
+
the chunk (or the first step of the next chunk). Besides, the next
|
471 |
+
timestep's hints are necessary to compute diffs when `decode_diffs`
|
472 |
+
is True.
|
473 |
+
repred: False during training, when we have access to ground-truth hints.
|
474 |
+
True in validation/test mode, when we have to use our own
|
475 |
+
hint predictions.
|
476 |
+
init_mp_state: Indicates if we are calling the method just to initialise
|
477 |
+
the message-passing state, before the beginning of training or
|
478 |
+
validation.
|
479 |
+
batch_size: Size of batch dimension.
|
480 |
+
nb_nodes: Number of nodes in graph.
|
481 |
+
spec: The spec of the algorithm being processed.
|
482 |
+
encs: encoders for the algorithm being processed.
|
483 |
+
decs: decoders for the algorithm being processed.
|
484 |
+
Returns:
|
485 |
+
A 2-tuple with the next mp_state and an output consisting of
|
486 |
+
hint predictions and output predictions.
|
487 |
+
"""
|
488 |
+
def _as_prediction_data(hint):
|
489 |
+
if hint.type_ == _Type.POINTER:
|
490 |
+
return hk.one_hot(hint.data, nb_nodes)
|
491 |
+
return hint.data
|
492 |
+
|
493 |
+
nxt_inputs, nxt_hints, nxt_is_first = xs
|
494 |
+
inputs = mp_state.inputs
|
495 |
+
is_first = mp_state.is_first
|
496 |
+
hints = mp_state.hints
|
497 |
+
if init_mp_state:
|
498 |
+
prev_hint_preds = {h.name: _as_prediction_data(h) for h in hints}
|
499 |
+
hints_for_pred = hints
|
500 |
+
else:
|
501 |
+
prev_hint_preds = mp_state.hint_preds
|
502 |
+
if self.decode_hints:
|
503 |
+
if repred:
|
504 |
+
force_mask = jnp.zeros(batch_size, dtype=bool)
|
505 |
+
elif self._hint_teacher_forcing == 1.0:
|
506 |
+
force_mask = jnp.ones(batch_size, dtype=bool)
|
507 |
+
else:
|
508 |
+
force_mask = jax.random.bernoulli(
|
509 |
+
hk.next_rng_key(), self._hint_teacher_forcing,
|
510 |
+
(batch_size,))
|
511 |
+
assert self._hint_repred_mode in ['soft', 'hard', 'hard_on_eval']
|
512 |
+
hard_postprocess = (
|
513 |
+
self._hint_repred_mode == 'hard' or
|
514 |
+
(self._hint_repred_mode == 'hard_on_eval' and repred))
|
515 |
+
decoded_hints = decoders.postprocess(spec,
|
516 |
+
prev_hint_preds,
|
517 |
+
sinkhorn_temperature=0.1,
|
518 |
+
sinkhorn_steps=25,
|
519 |
+
hard=hard_postprocess)
|
520 |
+
hints_for_pred = []
|
521 |
+
for h in hints:
|
522 |
+
typ = h.type_
|
523 |
+
hint_data = h.data
|
524 |
+
if (typ == _Type.POINTER and
|
525 |
+
decoded_hints[h.name].type_ == _Type.SOFT_POINTER):
|
526 |
+
hint_data = hk.one_hot(hint_data, nb_nodes)
|
527 |
+
typ = _Type.SOFT_POINTER
|
528 |
+
hints_for_pred.append(probing.DataPoint(
|
529 |
+
name=h.name, location=h.location, type_=typ,
|
530 |
+
data=jnp.where(_expand_to(is_first | force_mask, hint_data),
|
531 |
+
hint_data, decoded_hints[h.name].data)))
|
532 |
+
else:
|
533 |
+
hints_for_pred = hints
|
534 |
+
|
535 |
+
hiddens = jnp.where(is_first[..., None, None], 0.0, mp_state.hiddens)
|
536 |
+
if self.use_lstm:
|
537 |
+
lstm_state = jax.tree_util.tree_map(
|
538 |
+
lambda x: jnp.where(is_first[..., None, None], 0.0, x),
|
539 |
+
mp_state.lstm_state)
|
540 |
+
else:
|
541 |
+
lstm_state = None
|
542 |
+
hiddens, output_preds, hint_preds, lstm_state = self._one_step_pred(
|
543 |
+
inputs, hints_for_pred, hiddens,
|
544 |
+
batch_size, nb_nodes, lstm_state,
|
545 |
+
spec, encs, decs, repred)
|
546 |
+
|
547 |
+
new_mp_state = MessagePassingStateChunked( # pytype: disable=wrong-arg-types # numpy-scalars
|
548 |
+
hiddens=hiddens, lstm_state=lstm_state, hint_preds=hint_preds,
|
549 |
+
inputs=nxt_inputs, hints=nxt_hints, is_first=nxt_is_first)
|
550 |
+
mp_output = _MessagePassingOutputChunked( # pytype: disable=wrong-arg-types # numpy-scalars
|
551 |
+
hint_preds=hint_preds,
|
552 |
+
output_preds=output_preds)
|
553 |
+
return new_mp_state, mp_output
|
554 |
+
|
555 |
+
def __call__(self, features_list: List[_FeaturesChunked],
|
556 |
+
mp_state_list: List[MessagePassingStateChunked],
|
557 |
+
repred: bool, init_mp_state: bool,
|
558 |
+
algorithm_index: int):
|
559 |
+
"""Process one chunk of data.
|
560 |
+
|
561 |
+
Args:
|
562 |
+
features_list: A list of _FeaturesChunked objects, each with the
|
563 |
+
inputs, hints and beginning- and end-of-sample markers for
|
564 |
+
a chunk (i.e., fixed time length) of data corresponding to one
|
565 |
+
algorithm. All features are expected
|
566 |
+
to have dimensions chunk_length x batch_size x ...
|
567 |
+
The list should have either length 1, at train/evaluation time,
|
568 |
+
or length equal to the number of algorithms this Net is meant to
|
569 |
+
process, at initialization.
|
570 |
+
mp_state_list: list of message-passing states. Each message-passing state
|
571 |
+
includes the inputs, hints, beginning-of-sample markers,
|
572 |
+
hint prediction, hidden and lstm state from the end of the previous
|
573 |
+
chunk, for one algorithm. The length of the list should be the same
|
574 |
+
as the length of `features_list`.
|
575 |
+
repred: False during training, when we have access to ground-truth hints.
|
576 |
+
True in validation/test mode, when we have to use our own hint
|
577 |
+
predictions.
|
578 |
+
init_mp_state: Indicates if we are calling the network just to initialise
|
579 |
+
the message-passing state, before the beginning of training or
|
580 |
+
validation. If True, `algorithm_index` (see below) must be -1 in order
|
581 |
+
to initialize the message-passing state of all algorithms.
|
582 |
+
algorithm_index: Which algorithm is being processed. It can be -1 at
|
583 |
+
initialisation (either because we are initialising the parameters of
|
584 |
+
the module or because we are intialising the message-passing state),
|
585 |
+
meaning that all algorithms should be processed, in which case
|
586 |
+
`features_list` and `mp_state_list` should have length equal to the
|
587 |
+
number of specs of the Net. Otherwise, `algorithm_index` should be
|
588 |
+
between 0 and `length(self.spec) - 1`, meaning only one of the
|
589 |
+
algorithms will be processed, and `features_list` and `mp_state_list`
|
590 |
+
should have length 1.
|
591 |
+
|
592 |
+
Returns:
|
593 |
+
A 2-tuple consisting of:
|
594 |
+
- A 2-tuple with (output predictions, hint predictions)
|
595 |
+
for the selected algorithm. Each of these has
|
596 |
+
chunk_length x batch_size x ... data, where the first time
|
597 |
+
slice contains outputs for the mp_state
|
598 |
+
that was passed as input, and the last time slice contains outputs
|
599 |
+
for the next-to-last slice of the input features. The outputs that
|
600 |
+
correspond to the final time slice of the input features will be
|
601 |
+
calculated when the next chunk is processed, using the data in the
|
602 |
+
mp_state returned here (see below). If `init_mp_state` is True,
|
603 |
+
we return None instead of the 2-tuple.
|
604 |
+
- The mp_state (message-passing state) for the next chunk of data
|
605 |
+
of the selected algorithm. If `init_mp_state` is True, we return
|
606 |
+
initial mp states for all the algorithms.
|
607 |
+
"""
|
608 |
+
if algorithm_index == -1:
|
609 |
+
algorithm_indices = range(len(features_list))
|
610 |
+
else:
|
611 |
+
algorithm_indices = [algorithm_index]
|
612 |
+
assert not init_mp_state # init state only allowed with all algorithms
|
613 |
+
assert len(algorithm_indices) == len(features_list)
|
614 |
+
assert len(algorithm_indices) == len(mp_state_list)
|
615 |
+
|
616 |
+
self.encoders, self.decoders = self._construct_encoders_decoders()
|
617 |
+
self.processor = self.processor_factory(self.hidden_dim)
|
618 |
+
# Optionally construct LSTM.
|
619 |
+
if self.use_lstm:
|
620 |
+
self.lstm = hk.LSTM(
|
621 |
+
hidden_size=self.hidden_dim,
|
622 |
+
name='processor_lstm')
|
623 |
+
lstm_init = self.lstm.initial_state
|
624 |
+
else:
|
625 |
+
self.lstm = None
|
626 |
+
lstm_init = lambda x: 0
|
627 |
+
|
628 |
+
if init_mp_state:
|
629 |
+
output_mp_states = []
|
630 |
+
for algorithm_index, features, mp_state in zip(
|
631 |
+
algorithm_indices, features_list, mp_state_list):
|
632 |
+
inputs = features.inputs
|
633 |
+
hints = features.hints
|
634 |
+
batch_size, nb_nodes = _data_dimensions_chunked(features)
|
635 |
+
|
636 |
+
if self.use_lstm:
|
637 |
+
lstm_state = lstm_init(batch_size * nb_nodes)
|
638 |
+
lstm_state = jax.tree_util.tree_map(
|
639 |
+
lambda x, b=batch_size, n=nb_nodes: jnp.reshape(x, [b, n, -1]),
|
640 |
+
lstm_state)
|
641 |
+
mp_state.lstm_state = lstm_state
|
642 |
+
mp_state.inputs = jax.tree_util.tree_map(lambda x: x[0], inputs)
|
643 |
+
mp_state.hints = jax.tree_util.tree_map(lambda x: x[0], hints)
|
644 |
+
mp_state.is_first = jnp.zeros(batch_size, dtype=int)
|
645 |
+
mp_state.hiddens = jnp.zeros((batch_size, nb_nodes, self.hidden_dim))
|
646 |
+
next_is_first = jnp.ones(batch_size, dtype=int)
|
647 |
+
|
648 |
+
mp_state, _ = self._msg_passing_step(
|
649 |
+
mp_state,
|
650 |
+
(mp_state.inputs, mp_state.hints, next_is_first),
|
651 |
+
repred=repred,
|
652 |
+
init_mp_state=True,
|
653 |
+
batch_size=batch_size,
|
654 |
+
nb_nodes=nb_nodes,
|
655 |
+
spec=self.spec[algorithm_index],
|
656 |
+
encs=self.encoders[algorithm_index],
|
657 |
+
decs=self.decoders[algorithm_index],
|
658 |
+
)
|
659 |
+
output_mp_states.append(mp_state)
|
660 |
+
return None, output_mp_states
|
661 |
+
|
662 |
+
for algorithm_index, features, mp_state in zip(
|
663 |
+
algorithm_indices, features_list, mp_state_list):
|
664 |
+
inputs = features.inputs
|
665 |
+
hints = features.hints
|
666 |
+
is_first = features.is_first
|
667 |
+
batch_size, nb_nodes = _data_dimensions_chunked(features)
|
668 |
+
|
669 |
+
scan_fn = functools.partial(
|
670 |
+
self._msg_passing_step,
|
671 |
+
repred=repred,
|
672 |
+
init_mp_state=False,
|
673 |
+
batch_size=batch_size,
|
674 |
+
nb_nodes=nb_nodes,
|
675 |
+
spec=self.spec[algorithm_index],
|
676 |
+
encs=self.encoders[algorithm_index],
|
677 |
+
decs=self.decoders[algorithm_index],
|
678 |
+
)
|
679 |
+
|
680 |
+
mp_state, scan_output = hk.scan(
|
681 |
+
scan_fn,
|
682 |
+
mp_state,
|
683 |
+
(inputs, hints, is_first),
|
684 |
+
)
|
685 |
+
|
686 |
+
# We only return the last algorithm's output and state. That's because
|
687 |
+
# the output only matters when a single algorithm is processed; the case
|
688 |
+
# `algorithm_index==-1` (meaning all algorithms should be processed)
|
689 |
+
# is used only to init parameters.
|
690 |
+
return (scan_output.output_preds, scan_output.hint_preds), mp_state
|
691 |
+
|
692 |
+
|
693 |
+
def _data_dimensions(features: _Features) -> Tuple[int, int]:
|
694 |
+
"""Returns (batch_size, nb_nodes)."""
|
695 |
+
for inp in features.inputs:
|
696 |
+
if inp.location in [_Location.NODE, _Location.EDGE]:
|
697 |
+
return inp.data.shape[:2]
|
698 |
+
assert False
|
699 |
+
|
700 |
+
|
701 |
+
def _data_dimensions_chunked(features: _FeaturesChunked) -> Tuple[int, int]:
|
702 |
+
"""Returns (batch_size, nb_nodes)."""
|
703 |
+
for inp in features.inputs:
|
704 |
+
if inp.location in [_Location.NODE, _Location.EDGE]:
|
705 |
+
return inp.data.shape[1:3]
|
706 |
+
assert False
|
707 |
+
|
708 |
+
|
709 |
+
def _expand_to(x: _Array, y: _Array) -> _Array:
|
710 |
+
while len(y.shape) > len(x.shape):
|
711 |
+
x = jnp.expand_dims(x, -1)
|
712 |
+
return x
|
713 |
+
|
714 |
+
|
715 |
+
def _is_not_done_broadcast(lengths, i, tensor):
|
716 |
+
is_not_done = (lengths > i + 1) * 1.0
|
717 |
+
while len(is_not_done.shape) < len(tensor.shape): # pytype: disable=attribute-error # numpy-scalars
|
718 |
+
is_not_done = jnp.expand_dims(is_not_done, -1)
|
719 |
+
return is_not_done
|
benchmarks/CLRS/env/probing.py
ADDED
@@ -0,0 +1,351 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Probing utilities.
|
17 |
+
|
18 |
+
The dataflow for an algorithm is represented by `(stage, loc, type, data)`
|
19 |
+
"probes" that are valid under that algorithm's spec (see `specs.py`).
|
20 |
+
|
21 |
+
When constructing probes, it is convenient to represent these fields in a nested
|
22 |
+
format (`ProbesDict`) to facilate efficient contest-based look-up.
|
23 |
+
|
24 |
+
"""
|
25 |
+
|
26 |
+
import functools
|
27 |
+
from typing import Dict, List, Tuple, Union
|
28 |
+
|
29 |
+
import attr
|
30 |
+
from clrs._src import specs
|
31 |
+
import jax
|
32 |
+
import jax.numpy as jnp
|
33 |
+
import numpy as np
|
34 |
+
import tensorflow as tf
|
35 |
+
|
36 |
+
|
37 |
+
_Location = specs.Location
|
38 |
+
_Stage = specs.Stage
|
39 |
+
_Type = specs.Type
|
40 |
+
_OutputClass = specs.OutputClass
|
41 |
+
|
42 |
+
_Array = np.ndarray
|
43 |
+
_Data = Union[_Array, List[_Array]]
|
44 |
+
_DataOrType = Union[_Data, str]
|
45 |
+
|
46 |
+
ProbesDict = Dict[
|
47 |
+
str, Dict[str, Dict[str, Dict[str, _DataOrType]]]]
|
48 |
+
|
49 |
+
|
50 |
+
def _convert_to_str(element):
|
51 |
+
if isinstance(element, tf.Tensor):
|
52 |
+
return element.numpy().decode('utf-8')
|
53 |
+
elif isinstance(element, (np.ndarray, bytes)):
|
54 |
+
return element.decode('utf-8')
|
55 |
+
else:
|
56 |
+
return element
|
57 |
+
|
58 |
+
|
59 |
+
# First anotation makes this object jax.jit/pmap friendly, second one makes this
|
60 |
+
# tf.data.Datasets friendly.
|
61 |
+
@jax.tree_util.register_pytree_node_class
|
62 |
+
@attr.define
|
63 |
+
class DataPoint:
|
64 |
+
"""Describes a data point."""
|
65 |
+
|
66 |
+
_name: str
|
67 |
+
_location: str
|
68 |
+
_type_: str
|
69 |
+
data: _Array
|
70 |
+
|
71 |
+
@property
|
72 |
+
def name(self):
|
73 |
+
return _convert_to_str(self._name)
|
74 |
+
|
75 |
+
@property
|
76 |
+
def location(self):
|
77 |
+
return _convert_to_str(self._location)
|
78 |
+
|
79 |
+
@property
|
80 |
+
def type_(self):
|
81 |
+
return _convert_to_str(self._type_)
|
82 |
+
|
83 |
+
def __repr__(self):
|
84 |
+
s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t'
|
85 |
+
return s + f'type={self.type_},\tdata=Array{self.data.shape})'
|
86 |
+
|
87 |
+
def tree_flatten(self):
|
88 |
+
data = (self.data,)
|
89 |
+
meta = (self.name, self.location, self.type_)
|
90 |
+
return data, meta
|
91 |
+
|
92 |
+
@classmethod
|
93 |
+
def tree_unflatten(cls, meta, data):
|
94 |
+
name, location, type_ = meta
|
95 |
+
subdata, = data
|
96 |
+
return DataPoint(name, location, type_, subdata)
|
97 |
+
|
98 |
+
|
99 |
+
class ProbeError(Exception):
|
100 |
+
pass
|
101 |
+
|
102 |
+
|
103 |
+
def initialize(spec: specs.Spec) -> ProbesDict:
|
104 |
+
"""Initializes an empty `ProbesDict` corresponding with the provided spec."""
|
105 |
+
probes = dict()
|
106 |
+
for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]:
|
107 |
+
probes[stage] = {}
|
108 |
+
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
|
109 |
+
probes[stage][loc] = {}
|
110 |
+
|
111 |
+
for name in spec:
|
112 |
+
stage, loc, t = spec[name]
|
113 |
+
probes[stage][loc][name] = {}
|
114 |
+
probes[stage][loc][name]['data'] = []
|
115 |
+
probes[stage][loc][name]['type_'] = t
|
116 |
+
# Pytype thinks initialize() returns a ProbesDict with a str for all final
|
117 |
+
# values instead of _DataOrType.
|
118 |
+
return probes # pytype: disable=bad-return-type
|
119 |
+
|
120 |
+
|
121 |
+
def push(probes: ProbesDict, stage: str, next_probe):
|
122 |
+
"""Pushes a probe into an existing `ProbesDict`."""
|
123 |
+
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
|
124 |
+
for name in probes[stage][loc]:
|
125 |
+
if name not in next_probe:
|
126 |
+
raise ProbeError(f'Missing probe for {name}.')
|
127 |
+
if isinstance(probes[stage][loc][name]['data'], _Array):
|
128 |
+
raise ProbeError('Attemping to push to finalized `ProbesDict`.')
|
129 |
+
# Pytype thinks initialize() returns a ProbesDict with a str for all final
|
130 |
+
# values instead of _DataOrType.
|
131 |
+
probes[stage][loc][name]['data'].append(next_probe[name]) # pytype: disable=attribute-error
|
132 |
+
|
133 |
+
|
134 |
+
def finalize(probes: ProbesDict):
|
135 |
+
"""Finalizes a `ProbesDict` by stacking/squeezing `data` field."""
|
136 |
+
for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]:
|
137 |
+
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
|
138 |
+
for name in probes[stage][loc]:
|
139 |
+
if isinstance(probes[stage][loc][name]['data'], _Array):
|
140 |
+
raise ProbeError('Attemping to re-finalize a finalized `ProbesDict`.')
|
141 |
+
if stage == _Stage.HINT:
|
142 |
+
# Hints are provided for each timestep. Stack them here.
|
143 |
+
probes[stage][loc][name]['data'] = np.stack(
|
144 |
+
probes[stage][loc][name]['data'])
|
145 |
+
else:
|
146 |
+
# Only one instance of input/output exist. Remove leading axis.
|
147 |
+
probes[stage][loc][name]['data'] = np.squeeze(
|
148 |
+
np.array(probes[stage][loc][name]['data']))
|
149 |
+
|
150 |
+
|
151 |
+
def split_stages(
|
152 |
+
probes: ProbesDict,
|
153 |
+
spec: specs.Spec,
|
154 |
+
) -> Tuple[List[DataPoint], List[DataPoint], List[DataPoint]]:
|
155 |
+
"""Splits contents of `ProbesDict` into `DataPoint`s by stage."""
|
156 |
+
|
157 |
+
inputs = []
|
158 |
+
outputs = []
|
159 |
+
hints = []
|
160 |
+
|
161 |
+
for name in spec:
|
162 |
+
stage, loc, t = spec[name]
|
163 |
+
|
164 |
+
if stage not in probes:
|
165 |
+
raise ProbeError(f'Missing stage {stage}.')
|
166 |
+
if loc not in probes[stage]:
|
167 |
+
raise ProbeError(f'Missing location {loc}.')
|
168 |
+
if name not in probes[stage][loc]:
|
169 |
+
raise ProbeError(f'Missing probe {name}.')
|
170 |
+
if 'type_' not in probes[stage][loc][name]:
|
171 |
+
raise ProbeError(f'Probe {name} missing attribute `type_`.')
|
172 |
+
if 'data' not in probes[stage][loc][name]:
|
173 |
+
raise ProbeError(f'Probe {name} missing attribute `data`.')
|
174 |
+
if t != probes[stage][loc][name]['type_']:
|
175 |
+
raise ProbeError(f'Probe {name} of incorrect type {t}.')
|
176 |
+
|
177 |
+
data = probes[stage][loc][name]['data']
|
178 |
+
if not isinstance(probes[stage][loc][name]['data'], _Array):
|
179 |
+
raise ProbeError((f'Invalid `data` for probe "{name}". ' +
|
180 |
+
'Did you forget to call `probing.finalize`?'))
|
181 |
+
|
182 |
+
if t in [_Type.MASK, _Type.MASK_ONE, _Type.CATEGORICAL]:
|
183 |
+
# pytype: disable=attribute-error
|
184 |
+
if not ((data == 0) | (data == 1) | (data == -1)).all():
|
185 |
+
raise ProbeError(f'0|1|-1 `data` for probe "{name}"')
|
186 |
+
# pytype: enable=attribute-error
|
187 |
+
if t in [_Type.MASK_ONE, _Type.CATEGORICAL
|
188 |
+
] and not np.all(np.sum(np.abs(data), -1) == 1):
|
189 |
+
raise ProbeError(f'Expected one-hot `data` for probe "{name}"')
|
190 |
+
|
191 |
+
dim_to_expand = 1 if stage == _Stage.HINT else 0
|
192 |
+
data_point = DataPoint(name=name, location=loc, type_=t,
|
193 |
+
data=np.expand_dims(data, dim_to_expand))
|
194 |
+
|
195 |
+
if stage == _Stage.INPUT:
|
196 |
+
inputs.append(data_point)
|
197 |
+
elif stage == _Stage.OUTPUT:
|
198 |
+
outputs.append(data_point)
|
199 |
+
else:
|
200 |
+
hints.append(data_point)
|
201 |
+
|
202 |
+
return inputs, outputs, hints
|
203 |
+
|
204 |
+
|
205 |
+
# pylint: disable=invalid-name
|
206 |
+
|
207 |
+
|
208 |
+
def array(A_pos: np.ndarray) -> np.ndarray:
|
209 |
+
"""Constructs an `array` probe."""
|
210 |
+
probe = np.arange(A_pos.shape[0])
|
211 |
+
for i in range(1, A_pos.shape[0]):
|
212 |
+
probe[A_pos[i]] = A_pos[i - 1]
|
213 |
+
return probe
|
214 |
+
|
215 |
+
|
216 |
+
def array_cat(A: np.ndarray, n: int) -> np.ndarray:
|
217 |
+
"""Constructs an `array_cat` probe."""
|
218 |
+
assert n > 0
|
219 |
+
probe = np.zeros((A.shape[0], n))
|
220 |
+
for i in range(A.shape[0]):
|
221 |
+
probe[i, A[i]] = 1
|
222 |
+
return probe
|
223 |
+
|
224 |
+
|
225 |
+
def heap(A_pos: np.ndarray, heap_size: int) -> np.ndarray:
|
226 |
+
"""Constructs a `heap` probe."""
|
227 |
+
assert heap_size > 0
|
228 |
+
probe = np.arange(A_pos.shape[0])
|
229 |
+
for i in range(1, heap_size):
|
230 |
+
probe[A_pos[i]] = A_pos[(i - 1) // 2]
|
231 |
+
return probe
|
232 |
+
|
233 |
+
|
234 |
+
def graph(A: np.ndarray) -> np.ndarray:
|
235 |
+
"""Constructs a `graph` probe."""
|
236 |
+
probe = (A != 0) * 1.0
|
237 |
+
probe = ((A + np.eye(A.shape[0])) != 0) * 1.0
|
238 |
+
return probe
|
239 |
+
|
240 |
+
|
241 |
+
def mask_one(i: int, n: int) -> np.ndarray:
|
242 |
+
"""Constructs a `mask_one` probe."""
|
243 |
+
assert n > i
|
244 |
+
probe = np.zeros(n)
|
245 |
+
probe[i] = 1
|
246 |
+
return probe
|
247 |
+
|
248 |
+
|
249 |
+
def strings_id(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
|
250 |
+
"""Constructs a `strings_id` probe."""
|
251 |
+
probe_T = np.zeros(T_pos.shape[0])
|
252 |
+
probe_P = np.ones(P_pos.shape[0])
|
253 |
+
return np.concatenate([probe_T, probe_P])
|
254 |
+
|
255 |
+
|
256 |
+
def strings_pair(pair_probe: np.ndarray) -> np.ndarray:
|
257 |
+
"""Constructs a `strings_pair` probe."""
|
258 |
+
n = pair_probe.shape[0]
|
259 |
+
m = pair_probe.shape[1]
|
260 |
+
probe_ret = np.zeros((n + m, n + m))
|
261 |
+
for i in range(0, n):
|
262 |
+
for j in range(0, m):
|
263 |
+
probe_ret[i, j + n] = pair_probe[i, j]
|
264 |
+
return probe_ret
|
265 |
+
|
266 |
+
|
267 |
+
def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray:
|
268 |
+
"""Constructs a `strings_pair_cat` probe."""
|
269 |
+
assert nb_classes > 0
|
270 |
+
n = pair_probe.shape[0]
|
271 |
+
m = pair_probe.shape[1]
|
272 |
+
|
273 |
+
# Add an extra class for 'this cell left blank.'
|
274 |
+
probe_ret = np.zeros((n + m, n + m, nb_classes + 1))
|
275 |
+
for i in range(0, n):
|
276 |
+
for j in range(0, m):
|
277 |
+
probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE
|
278 |
+
|
279 |
+
# Fill the blank cells.
|
280 |
+
for i_1 in range(0, n):
|
281 |
+
for i_2 in range(0, n):
|
282 |
+
probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED
|
283 |
+
for j_1 in range(0, m):
|
284 |
+
for x in range(0, n + m):
|
285 |
+
probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED
|
286 |
+
return probe_ret
|
287 |
+
|
288 |
+
|
289 |
+
def strings_pi(T_pos: np.ndarray, P_pos: np.ndarray,
|
290 |
+
pi: np.ndarray) -> np.ndarray:
|
291 |
+
"""Constructs a `strings_pi` probe."""
|
292 |
+
probe = np.arange(T_pos.shape[0] + P_pos.shape[0])
|
293 |
+
for j in range(P_pos.shape[0]):
|
294 |
+
probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + pi[P_pos[j]]
|
295 |
+
return probe
|
296 |
+
|
297 |
+
|
298 |
+
def strings_pos(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
|
299 |
+
"""Constructs a `strings_pos` probe."""
|
300 |
+
probe_T = np.copy(T_pos) * 1.0 / T_pos.shape[0]
|
301 |
+
probe_P = np.copy(P_pos) * 1.0 / P_pos.shape[0]
|
302 |
+
return np.concatenate([probe_T, probe_P])
|
303 |
+
|
304 |
+
|
305 |
+
def strings_pred(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray:
|
306 |
+
"""Constructs a `strings_pred` probe."""
|
307 |
+
probe = np.arange(T_pos.shape[0] + P_pos.shape[0])
|
308 |
+
for i in range(1, T_pos.shape[0]):
|
309 |
+
probe[T_pos[i]] = T_pos[i - 1]
|
310 |
+
for j in range(1, P_pos.shape[0]):
|
311 |
+
probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + P_pos[j - 1]
|
312 |
+
return probe
|
313 |
+
|
314 |
+
|
315 |
+
@functools.partial(jnp.vectorize, signature='(n)->(n,n),(n)')
|
316 |
+
def predecessor_to_cyclic_predecessor_and_first(
|
317 |
+
pointers: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
|
318 |
+
"""Converts predecessor pointers to cyclic predecessor + first node mask.
|
319 |
+
|
320 |
+
This function assumes that the pointers represent a linear order of the nodes
|
321 |
+
(akin to a linked list), where each node points to its predecessor and the
|
322 |
+
first node points to itself. It returns the same pointers, except that
|
323 |
+
the first node points to the last, and a mask_one marking the first node.
|
324 |
+
|
325 |
+
Example:
|
326 |
+
```
|
327 |
+
pointers = [2, 1, 1]
|
328 |
+
P = [[0, 0, 1],
|
329 |
+
[1, 0, 0],
|
330 |
+
[0, 1, 0]],
|
331 |
+
M = [0, 1, 0]
|
332 |
+
```
|
333 |
+
|
334 |
+
Args:
|
335 |
+
pointers: array of shape [N] containing pointers. The pointers are assumed
|
336 |
+
to describe a linear order such that `pointers[i]` is the predecessor
|
337 |
+
of node `i`.
|
338 |
+
|
339 |
+
Returns:
|
340 |
+
Permutation pointers `P` of shape [N] and one-hot vector `M` of shape [N].
|
341 |
+
"""
|
342 |
+
nb_nodes = pointers.shape[-1]
|
343 |
+
pointers_one_hot = jax.nn.one_hot(pointers, nb_nodes)
|
344 |
+
# Find the index of the last node: it's the node that no other node points to.
|
345 |
+
last = pointers_one_hot.sum(-2).argmin()
|
346 |
+
# Find the first node: should be the only one pointing to itself.
|
347 |
+
first = pointers_one_hot.diagonal().argmax()
|
348 |
+
mask = jax.nn.one_hot(first, nb_nodes)
|
349 |
+
pointers_one_hot += mask[..., None] * jax.nn.one_hot(last, nb_nodes)
|
350 |
+
pointers_one_hot -= mask[..., None] * mask
|
351 |
+
return pointers_one_hot, mask
|
benchmarks/CLRS/env/probing_test.py
ADDED
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Unit tests for `probing.py`."""
|
17 |
+
|
18 |
+
from absl.testing import absltest
|
19 |
+
|
20 |
+
from clrs._src import probing
|
21 |
+
import jax.numpy as jnp
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
|
25 |
+
# pylint: disable=invalid-name
|
26 |
+
|
27 |
+
|
28 |
+
class ProbingTest(absltest.TestCase):
|
29 |
+
|
30 |
+
def test_array(self):
|
31 |
+
A_pos = np.array([1, 2, 0, 4, 3])
|
32 |
+
expected = np.array([2, 1, 1, 4, 0])
|
33 |
+
out = probing.array(A_pos)
|
34 |
+
np.testing.assert_array_equal(expected, out)
|
35 |
+
|
36 |
+
def test_array_cat(self):
|
37 |
+
A = np.array([2, 1, 0, 1, 1])
|
38 |
+
expected = np.array([
|
39 |
+
[0, 0, 1],
|
40 |
+
[0, 1, 0],
|
41 |
+
[1, 0, 0],
|
42 |
+
[0, 1, 0],
|
43 |
+
[0, 1, 0]
|
44 |
+
])
|
45 |
+
out = probing.array_cat(A, 3)
|
46 |
+
np.testing.assert_array_equal(expected, out)
|
47 |
+
|
48 |
+
def test_heap(self):
|
49 |
+
A_pos = np.array([1, 3, 5, 0, 7, 4, 2, 6])
|
50 |
+
expected = np.array([3, 1, 2, 1, 5, 1, 6, 3])
|
51 |
+
out = probing.heap(A_pos, heap_size=6)
|
52 |
+
np.testing.assert_array_equal(expected, out)
|
53 |
+
|
54 |
+
def test_graph(self):
|
55 |
+
G = np.array([
|
56 |
+
[0.0, 7.0, -1.0, -3.9, 7.452],
|
57 |
+
[0.0, 0.0, 133.0, 0.0, 9.3],
|
58 |
+
[0.5, 0.1, 0.22, 0.55, 0.666],
|
59 |
+
[7.0, 6.1, 0.2, 0.0, 0.0],
|
60 |
+
[0.0, 3.0, 0.0, 1.0, 0.5]
|
61 |
+
])
|
62 |
+
expected = np.array([
|
63 |
+
[1.0, 1.0, 1.0, 1.0, 1.0],
|
64 |
+
[0.0, 1.0, 1.0, 0.0, 1.0],
|
65 |
+
[1.0, 1.0, 1.0, 1.0, 1.0],
|
66 |
+
[1.0, 1.0, 1.0, 1.0, 0.0],
|
67 |
+
[0.0, 1.0, 0.0, 1.0, 1.0]
|
68 |
+
])
|
69 |
+
out = probing.graph(G)
|
70 |
+
np.testing.assert_array_equal(expected, out)
|
71 |
+
|
72 |
+
def test_mask_one(self):
|
73 |
+
expected = np.array([0, 0, 0, 1, 0])
|
74 |
+
out = probing.mask_one(3, 5)
|
75 |
+
np.testing.assert_array_equal(expected, out)
|
76 |
+
|
77 |
+
def test_strings_id(self):
|
78 |
+
T_pos = np.array([0, 1, 2, 3, 4])
|
79 |
+
P_pos = np.array([0, 1, 2])
|
80 |
+
expected = np.array([0, 0, 0, 0, 0, 1, 1, 1])
|
81 |
+
out = probing.strings_id(T_pos, P_pos)
|
82 |
+
np.testing.assert_array_equal(expected, out)
|
83 |
+
|
84 |
+
def test_strings_pair(self):
|
85 |
+
pair_probe = np.array([
|
86 |
+
[0.5, 3.1, 9.1, 7.3],
|
87 |
+
[1.0, 0.0, 8.0, 9.3],
|
88 |
+
[0.1, 5.0, 0.0, 1.2]
|
89 |
+
])
|
90 |
+
expected = np.array([
|
91 |
+
[0.0, 0.0, 0.0, 0.5, 3.1, 9.1, 7.3],
|
92 |
+
[0.0, 0.0, 0.0, 1.0, 0.0, 8.0, 9.3],
|
93 |
+
[0.0, 0.0, 0.0, 0.1, 5.0, 0.0, 1.2],
|
94 |
+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
95 |
+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
96 |
+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
|
97 |
+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
|
98 |
+
])
|
99 |
+
out = probing.strings_pair(pair_probe)
|
100 |
+
np.testing.assert_equal(expected, out)
|
101 |
+
|
102 |
+
def test_strings_pair_cat(self):
|
103 |
+
pair_probe = np.array([
|
104 |
+
[0, 2, 1],
|
105 |
+
[2, 2, 0]
|
106 |
+
])
|
107 |
+
expected = np.array([
|
108 |
+
[
|
109 |
+
[0, 0, 0, -1],
|
110 |
+
[0, 0, 0, -1],
|
111 |
+
[1, 0, 0, 0],
|
112 |
+
[0, 0, 1, 0],
|
113 |
+
[0, 1, 0, 0],
|
114 |
+
],
|
115 |
+
[
|
116 |
+
[0, 0, 0, -1],
|
117 |
+
[0, 0, 0, -1],
|
118 |
+
[0, 0, 1, 0],
|
119 |
+
[0, 0, 1, 0],
|
120 |
+
[1, 0, 0, 0],
|
121 |
+
],
|
122 |
+
[
|
123 |
+
[0, 0, 0, -1],
|
124 |
+
[0, 0, 0, -1],
|
125 |
+
[0, 0, 0, -1],
|
126 |
+
[0, 0, 0, -1],
|
127 |
+
[0, 0, 0, -1],
|
128 |
+
],
|
129 |
+
[
|
130 |
+
[0, 0, 0, -1],
|
131 |
+
[0, 0, 0, -1],
|
132 |
+
[0, 0, 0, -1],
|
133 |
+
[0, 0, 0, -1],
|
134 |
+
[0, 0, 0, -1],
|
135 |
+
],
|
136 |
+
[
|
137 |
+
[0, 0, 0, -1],
|
138 |
+
[0, 0, 0, -1],
|
139 |
+
[0, 0, 0, -1],
|
140 |
+
[0, 0, 0, -1],
|
141 |
+
[0, 0, 0, -1],
|
142 |
+
],
|
143 |
+
])
|
144 |
+
out = probing.strings_pair_cat(pair_probe, 3)
|
145 |
+
np.testing.assert_equal(expected, out)
|
146 |
+
|
147 |
+
def test_strings_pi(self):
|
148 |
+
T_pos = np.array([0, 1, 2, 3, 4, 5])
|
149 |
+
P_pos = np.array([0, 1, 2, 3])
|
150 |
+
pi = np.array([3, 1, 0, 2])
|
151 |
+
expected = np.array(
|
152 |
+
[0, 1, 2, 3, 4, 5, 9, 7, 6, 8]
|
153 |
+
)
|
154 |
+
out = probing.strings_pi(T_pos, P_pos, pi)
|
155 |
+
np.testing.assert_array_equal(expected, out)
|
156 |
+
|
157 |
+
def test_strings_pos(self):
|
158 |
+
T_pos = np.array([0, 1, 2, 3, 4])
|
159 |
+
P_pos = np.array([0, 1, 2, 3])
|
160 |
+
expected = np.array(
|
161 |
+
[0.0, 0.2, 0.4, 0.6, 0.8,
|
162 |
+
0.0, 0.25, 0.5, 0.75]
|
163 |
+
)
|
164 |
+
out = probing.strings_pos(T_pos, P_pos)
|
165 |
+
np.testing.assert_array_equal(expected, out)
|
166 |
+
|
167 |
+
def test_strings_pred(self):
|
168 |
+
T_pos = np.array([0, 1, 2, 3, 4])
|
169 |
+
P_pos = np.array([0, 1, 2])
|
170 |
+
expected = np.array([0, 0, 1, 2, 3, 5, 5, 6])
|
171 |
+
out = probing.strings_pred(T_pos, P_pos)
|
172 |
+
np.testing.assert_array_equal(expected, out)
|
173 |
+
|
174 |
+
|
175 |
+
class PermutationsTest(absltest.TestCase):
|
176 |
+
|
177 |
+
def test_pointers_to_permutation(self):
|
178 |
+
pointers = jnp.array([2, 1, 1])
|
179 |
+
perm, first = probing.predecessor_to_cyclic_predecessor_and_first(pointers)
|
180 |
+
np.testing.assert_array_equal(
|
181 |
+
perm, np.array([[0, 0, 1], [1, 0, 0], [0, 1, 0]]))
|
182 |
+
np.testing.assert_array_equal(first, np.array([0, 1, 0]))
|
183 |
+
|
184 |
+
def test_pointers_to_permutation_already_sorted(self):
|
185 |
+
pointers = jnp.array([0, 0, 1, 2, 3, 4])
|
186 |
+
perm, first = probing.predecessor_to_cyclic_predecessor_and_first(pointers)
|
187 |
+
np.testing.assert_array_equal(perm, np.roll(np.eye(6), 1, 0))
|
188 |
+
np.testing.assert_array_equal(first, np.eye(6)[0])
|
189 |
+
|
190 |
+
|
191 |
+
if __name__ == "__main__":
|
192 |
+
absltest.main()
|
benchmarks/CLRS/env/processors.py
ADDED
@@ -0,0 +1,856 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""JAX implementation of baseline processor networks."""
|
17 |
+
|
18 |
+
import abc
|
19 |
+
from typing import Any, Callable, List, Optional, Tuple
|
20 |
+
|
21 |
+
import chex
|
22 |
+
import haiku as hk
|
23 |
+
import jax
|
24 |
+
import jax.numpy as jnp
|
25 |
+
import numpy as np
|
26 |
+
|
27 |
+
|
28 |
+
_Array = chex.Array
|
29 |
+
_Fn = Callable[..., Any]
|
30 |
+
BIG_NUMBER = 1e6
|
31 |
+
PROCESSOR_TAG = 'clrs_processor'
|
32 |
+
|
33 |
+
|
34 |
+
class Processor(hk.Module):
|
35 |
+
"""Processor abstract base class."""
|
36 |
+
|
37 |
+
def __init__(self, name: str):
|
38 |
+
if not name.endswith(PROCESSOR_TAG):
|
39 |
+
name = name + '_' + PROCESSOR_TAG
|
40 |
+
super().__init__(name=name)
|
41 |
+
|
42 |
+
@abc.abstractmethod
|
43 |
+
def __call__(
|
44 |
+
self,
|
45 |
+
node_fts: _Array,
|
46 |
+
edge_fts: _Array,
|
47 |
+
graph_fts: _Array,
|
48 |
+
adj_mat: _Array,
|
49 |
+
hidden: _Array,
|
50 |
+
**kwargs,
|
51 |
+
) -> Tuple[_Array, Optional[_Array]]:
|
52 |
+
"""Processor inference step.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
node_fts: Node features.
|
56 |
+
edge_fts: Edge features.
|
57 |
+
graph_fts: Graph features.
|
58 |
+
adj_mat: Graph adjacency matrix.
|
59 |
+
hidden: Hidden features.
|
60 |
+
**kwargs: Extra kwargs.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
Output of processor inference step as a 2-tuple of (node, edge)
|
64 |
+
embeddings. The edge embeddings can be None.
|
65 |
+
"""
|
66 |
+
pass
|
67 |
+
|
68 |
+
@property
|
69 |
+
def inf_bias(self):
|
70 |
+
return False
|
71 |
+
|
72 |
+
@property
|
73 |
+
def inf_bias_edge(self):
|
74 |
+
return False
|
75 |
+
|
76 |
+
|
77 |
+
class GAT(Processor):
|
78 |
+
"""Graph Attention Network (Velickovic et al., ICLR 2018)."""
|
79 |
+
|
80 |
+
def __init__(
|
81 |
+
self,
|
82 |
+
out_size: int,
|
83 |
+
nb_heads: int,
|
84 |
+
activation: Optional[_Fn] = jax.nn.relu,
|
85 |
+
residual: bool = True,
|
86 |
+
use_ln: bool = False,
|
87 |
+
name: str = 'gat_aggr',
|
88 |
+
):
|
89 |
+
super().__init__(name=name)
|
90 |
+
self.out_size = out_size
|
91 |
+
self.nb_heads = nb_heads
|
92 |
+
if out_size % nb_heads != 0:
|
93 |
+
raise ValueError('The number of attention heads must divide the width!')
|
94 |
+
self.head_size = out_size // nb_heads
|
95 |
+
self.activation = activation
|
96 |
+
self.residual = residual
|
97 |
+
self.use_ln = use_ln
|
98 |
+
|
99 |
+
def __call__( # pytype: disable=signature-mismatch # numpy-scalars
|
100 |
+
self,
|
101 |
+
node_fts: _Array,
|
102 |
+
edge_fts: _Array,
|
103 |
+
graph_fts: _Array,
|
104 |
+
adj_mat: _Array,
|
105 |
+
hidden: _Array,
|
106 |
+
**unused_kwargs,
|
107 |
+
) -> _Array:
|
108 |
+
"""GAT inference step."""
|
109 |
+
|
110 |
+
b, n, _ = node_fts.shape
|
111 |
+
assert edge_fts.shape[:-1] == (b, n, n)
|
112 |
+
assert graph_fts.shape[:-1] == (b,)
|
113 |
+
assert adj_mat.shape == (b, n, n)
|
114 |
+
|
115 |
+
z = jnp.concatenate([node_fts, hidden], axis=-1)
|
116 |
+
m = hk.Linear(self.out_size)
|
117 |
+
skip = hk.Linear(self.out_size)
|
118 |
+
|
119 |
+
bias_mat = (adj_mat - 1.0) * 1e9
|
120 |
+
bias_mat = jnp.tile(bias_mat[..., None],
|
121 |
+
(1, 1, 1, self.nb_heads)) # [B, N, N, H]
|
122 |
+
bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) # [B, H, N, N]
|
123 |
+
|
124 |
+
a_1 = hk.Linear(self.nb_heads)
|
125 |
+
a_2 = hk.Linear(self.nb_heads)
|
126 |
+
a_e = hk.Linear(self.nb_heads)
|
127 |
+
a_g = hk.Linear(self.nb_heads)
|
128 |
+
|
129 |
+
values = m(z) # [B, N, H*F]
|
130 |
+
values = jnp.reshape(
|
131 |
+
values,
|
132 |
+
values.shape[:-1] + (self.nb_heads, self.head_size)) # [B, N, H, F]
|
133 |
+
values = jnp.transpose(values, (0, 2, 1, 3)) # [B, H, N, F]
|
134 |
+
|
135 |
+
att_1 = jnp.expand_dims(a_1(z), axis=-1)
|
136 |
+
att_2 = jnp.expand_dims(a_2(z), axis=-1)
|
137 |
+
att_e = a_e(edge_fts)
|
138 |
+
att_g = jnp.expand_dims(a_g(graph_fts), axis=-1)
|
139 |
+
|
140 |
+
logits = (
|
141 |
+
jnp.transpose(att_1, (0, 2, 1, 3)) + # + [B, H, N, 1]
|
142 |
+
jnp.transpose(att_2, (0, 2, 3, 1)) + # + [B, H, 1, N]
|
143 |
+
jnp.transpose(att_e, (0, 3, 1, 2)) + # + [B, H, N, N]
|
144 |
+
jnp.expand_dims(att_g, axis=-1) # + [B, H, 1, 1]
|
145 |
+
) # = [B, H, N, N]
|
146 |
+
coefs = jax.nn.softmax(jax.nn.leaky_relu(logits) + bias_mat, axis=-1)
|
147 |
+
ret = jnp.matmul(coefs, values) # [B, H, N, F]
|
148 |
+
ret = jnp.transpose(ret, (0, 2, 1, 3)) # [B, N, H, F]
|
149 |
+
ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,)) # [B, N, H*F]
|
150 |
+
|
151 |
+
if self.residual:
|
152 |
+
ret += skip(z)
|
153 |
+
|
154 |
+
if self.activation is not None:
|
155 |
+
ret = self.activation(ret)
|
156 |
+
|
157 |
+
if self.use_ln:
|
158 |
+
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
|
159 |
+
ret = ln(ret)
|
160 |
+
|
161 |
+
return ret, None # pytype: disable=bad-return-type # numpy-scalars
|
162 |
+
|
163 |
+
|
164 |
+
class GATFull(GAT):
|
165 |
+
"""Graph Attention Network with full adjacency matrix."""
|
166 |
+
|
167 |
+
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
|
168 |
+
adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
|
169 |
+
adj_mat = jnp.ones_like(adj_mat)
|
170 |
+
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
|
171 |
+
|
172 |
+
|
173 |
+
class GATv2(Processor):
|
174 |
+
"""Graph Attention Network v2 (Brody et al., ICLR 2022)."""
|
175 |
+
|
176 |
+
def __init__(
|
177 |
+
self,
|
178 |
+
out_size: int,
|
179 |
+
nb_heads: int,
|
180 |
+
mid_size: Optional[int] = None,
|
181 |
+
activation: Optional[_Fn] = jax.nn.relu,
|
182 |
+
residual: bool = True,
|
183 |
+
use_ln: bool = False,
|
184 |
+
name: str = 'gatv2_aggr',
|
185 |
+
):
|
186 |
+
super().__init__(name=name)
|
187 |
+
if mid_size is None:
|
188 |
+
self.mid_size = out_size
|
189 |
+
else:
|
190 |
+
self.mid_size = mid_size
|
191 |
+
self.out_size = out_size
|
192 |
+
self.nb_heads = nb_heads
|
193 |
+
if out_size % nb_heads != 0:
|
194 |
+
raise ValueError('The number of attention heads must divide the width!')
|
195 |
+
self.head_size = out_size // nb_heads
|
196 |
+
if self.mid_size % nb_heads != 0:
|
197 |
+
raise ValueError('The number of attention heads must divide the message!')
|
198 |
+
self.mid_head_size = self.mid_size // nb_heads
|
199 |
+
self.activation = activation
|
200 |
+
self.residual = residual
|
201 |
+
self.use_ln = use_ln
|
202 |
+
|
203 |
+
def __call__( # pytype: disable=signature-mismatch # numpy-scalars
|
204 |
+
self,
|
205 |
+
node_fts: _Array,
|
206 |
+
edge_fts: _Array,
|
207 |
+
graph_fts: _Array,
|
208 |
+
adj_mat: _Array,
|
209 |
+
hidden: _Array,
|
210 |
+
**unused_kwargs,
|
211 |
+
) -> _Array:
|
212 |
+
"""GATv2 inference step."""
|
213 |
+
|
214 |
+
b, n, _ = node_fts.shape
|
215 |
+
assert edge_fts.shape[:-1] == (b, n, n)
|
216 |
+
assert graph_fts.shape[:-1] == (b,)
|
217 |
+
assert adj_mat.shape == (b, n, n)
|
218 |
+
|
219 |
+
z = jnp.concatenate([node_fts, hidden], axis=-1)
|
220 |
+
m = hk.Linear(self.out_size)
|
221 |
+
skip = hk.Linear(self.out_size)
|
222 |
+
|
223 |
+
bias_mat = (adj_mat - 1.0) * 1e9
|
224 |
+
bias_mat = jnp.tile(bias_mat[..., None],
|
225 |
+
(1, 1, 1, self.nb_heads)) # [B, N, N, H]
|
226 |
+
bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) # [B, H, N, N]
|
227 |
+
|
228 |
+
w_1 = hk.Linear(self.mid_size)
|
229 |
+
w_2 = hk.Linear(self.mid_size)
|
230 |
+
w_e = hk.Linear(self.mid_size)
|
231 |
+
w_g = hk.Linear(self.mid_size)
|
232 |
+
|
233 |
+
a_heads = []
|
234 |
+
for _ in range(self.nb_heads):
|
235 |
+
a_heads.append(hk.Linear(1))
|
236 |
+
|
237 |
+
values = m(z) # [B, N, H*F]
|
238 |
+
values = jnp.reshape(
|
239 |
+
values,
|
240 |
+
values.shape[:-1] + (self.nb_heads, self.head_size)) # [B, N, H, F]
|
241 |
+
values = jnp.transpose(values, (0, 2, 1, 3)) # [B, H, N, F]
|
242 |
+
|
243 |
+
pre_att_1 = w_1(z)
|
244 |
+
pre_att_2 = w_2(z)
|
245 |
+
pre_att_e = w_e(edge_fts)
|
246 |
+
pre_att_g = w_g(graph_fts)
|
247 |
+
|
248 |
+
pre_att = (
|
249 |
+
jnp.expand_dims(pre_att_1, axis=1) + # + [B, 1, N, H*F]
|
250 |
+
jnp.expand_dims(pre_att_2, axis=2) + # + [B, N, 1, H*F]
|
251 |
+
pre_att_e + # + [B, N, N, H*F]
|
252 |
+
jnp.expand_dims(pre_att_g, axis=(1, 2)) # + [B, 1, 1, H*F]
|
253 |
+
) # = [B, N, N, H*F]
|
254 |
+
|
255 |
+
pre_att = jnp.reshape(
|
256 |
+
pre_att,
|
257 |
+
pre_att.shape[:-1] + (self.nb_heads, self.mid_head_size)
|
258 |
+
) # [B, N, N, H, F]
|
259 |
+
|
260 |
+
pre_att = jnp.transpose(pre_att, (0, 3, 1, 2, 4)) # [B, H, N, N, F]
|
261 |
+
|
262 |
+
# This part is not very efficient, but we agree to keep it this way to
|
263 |
+
# enhance readability, assuming `nb_heads` will not be large.
|
264 |
+
logit_heads = []
|
265 |
+
for head in range(self.nb_heads):
|
266 |
+
logit_heads.append(
|
267 |
+
jnp.squeeze(
|
268 |
+
a_heads[head](jax.nn.leaky_relu(pre_att[:, head])),
|
269 |
+
axis=-1)
|
270 |
+
) # [B, N, N]
|
271 |
+
|
272 |
+
logits = jnp.stack(logit_heads, axis=1) # [B, H, N, N]
|
273 |
+
|
274 |
+
coefs = jax.nn.softmax(logits + bias_mat, axis=-1)
|
275 |
+
ret = jnp.matmul(coefs, values) # [B, H, N, F]
|
276 |
+
ret = jnp.transpose(ret, (0, 2, 1, 3)) # [B, N, H, F]
|
277 |
+
ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,)) # [B, N, H*F]
|
278 |
+
|
279 |
+
if self.residual:
|
280 |
+
ret += skip(z)
|
281 |
+
|
282 |
+
if self.activation is not None:
|
283 |
+
ret = self.activation(ret)
|
284 |
+
|
285 |
+
if self.use_ln:
|
286 |
+
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
|
287 |
+
ret = ln(ret)
|
288 |
+
|
289 |
+
return ret, None # pytype: disable=bad-return-type # numpy-scalars
|
290 |
+
|
291 |
+
|
292 |
+
class GATv2Full(GATv2):
|
293 |
+
"""Graph Attention Network v2 with full adjacency matrix."""
|
294 |
+
|
295 |
+
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
|
296 |
+
adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
|
297 |
+
adj_mat = jnp.ones_like(adj_mat)
|
298 |
+
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
|
299 |
+
|
300 |
+
|
301 |
+
def get_triplet_msgs(z, edge_fts, graph_fts, nb_triplet_fts):
|
302 |
+
"""Triplet messages, as done by Dudzik and Velickovic (2022)."""
|
303 |
+
t_1 = hk.Linear(nb_triplet_fts)
|
304 |
+
t_2 = hk.Linear(nb_triplet_fts)
|
305 |
+
t_3 = hk.Linear(nb_triplet_fts)
|
306 |
+
t_e_1 = hk.Linear(nb_triplet_fts)
|
307 |
+
t_e_2 = hk.Linear(nb_triplet_fts)
|
308 |
+
t_e_3 = hk.Linear(nb_triplet_fts)
|
309 |
+
t_g = hk.Linear(nb_triplet_fts)
|
310 |
+
|
311 |
+
tri_1 = t_1(z)
|
312 |
+
tri_2 = t_2(z)
|
313 |
+
tri_3 = t_3(z)
|
314 |
+
tri_e_1 = t_e_1(edge_fts)
|
315 |
+
tri_e_2 = t_e_2(edge_fts)
|
316 |
+
tri_e_3 = t_e_3(edge_fts)
|
317 |
+
tri_g = t_g(graph_fts)
|
318 |
+
|
319 |
+
return (
|
320 |
+
jnp.expand_dims(tri_1, axis=(2, 3)) + # (B, N, 1, 1, H)
|
321 |
+
jnp.expand_dims(tri_2, axis=(1, 3)) + # + (B, 1, N, 1, H)
|
322 |
+
jnp.expand_dims(tri_3, axis=(1, 2)) + # + (B, 1, 1, N, H)
|
323 |
+
jnp.expand_dims(tri_e_1, axis=3) + # + (B, N, N, 1, H)
|
324 |
+
jnp.expand_dims(tri_e_2, axis=2) + # + (B, N, 1, N, H)
|
325 |
+
jnp.expand_dims(tri_e_3, axis=1) + # + (B, 1, N, N, H)
|
326 |
+
jnp.expand_dims(tri_g, axis=(1, 2, 3)) # + (B, 1, 1, 1, H)
|
327 |
+
) # = (B, N, N, N, H)
|
328 |
+
|
329 |
+
|
330 |
+
class PGN(Processor):
|
331 |
+
"""Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""
|
332 |
+
|
333 |
+
def __init__(
|
334 |
+
self,
|
335 |
+
out_size: int,
|
336 |
+
mid_size: Optional[int] = None,
|
337 |
+
mid_act: Optional[_Fn] = None,
|
338 |
+
activation: Optional[_Fn] = jax.nn.relu,
|
339 |
+
reduction: _Fn = jnp.max,
|
340 |
+
msgs_mlp_sizes: Optional[List[int]] = None,
|
341 |
+
use_ln: bool = False,
|
342 |
+
use_triplets: bool = False,
|
343 |
+
nb_triplet_fts: int = 8,
|
344 |
+
gated: bool = False,
|
345 |
+
name: str = 'mpnn_aggr',
|
346 |
+
):
|
347 |
+
super().__init__(name=name)
|
348 |
+
if mid_size is None:
|
349 |
+
self.mid_size = out_size
|
350 |
+
else:
|
351 |
+
self.mid_size = mid_size
|
352 |
+
self.out_size = out_size
|
353 |
+
self.mid_act = mid_act
|
354 |
+
self.activation = activation
|
355 |
+
self.reduction = reduction
|
356 |
+
self._msgs_mlp_sizes = msgs_mlp_sizes
|
357 |
+
self.use_ln = use_ln
|
358 |
+
self.use_triplets = use_triplets
|
359 |
+
self.nb_triplet_fts = nb_triplet_fts
|
360 |
+
self.gated = gated
|
361 |
+
|
362 |
+
def __call__( # pytype: disable=signature-mismatch # numpy-scalars
|
363 |
+
self,
|
364 |
+
node_fts: _Array,
|
365 |
+
edge_fts: _Array,
|
366 |
+
graph_fts: _Array,
|
367 |
+
adj_mat: _Array,
|
368 |
+
hidden: _Array,
|
369 |
+
**unused_kwargs,
|
370 |
+
) -> _Array:
|
371 |
+
"""MPNN inference step."""
|
372 |
+
|
373 |
+
b, n, _ = node_fts.shape
|
374 |
+
assert edge_fts.shape[:-1] == (b, n, n)
|
375 |
+
assert graph_fts.shape[:-1] == (b,)
|
376 |
+
assert adj_mat.shape == (b, n, n)
|
377 |
+
|
378 |
+
z = jnp.concatenate([node_fts, hidden], axis=-1)
|
379 |
+
m_1 = hk.Linear(self.mid_size)
|
380 |
+
m_2 = hk.Linear(self.mid_size)
|
381 |
+
m_e = hk.Linear(self.mid_size)
|
382 |
+
m_g = hk.Linear(self.mid_size)
|
383 |
+
|
384 |
+
o1 = hk.Linear(self.out_size)
|
385 |
+
o2 = hk.Linear(self.out_size)
|
386 |
+
|
387 |
+
msg_1 = m_1(z)
|
388 |
+
msg_2 = m_2(z)
|
389 |
+
msg_e = m_e(edge_fts)
|
390 |
+
msg_g = m_g(graph_fts)
|
391 |
+
|
392 |
+
tri_msgs = None
|
393 |
+
|
394 |
+
if self.use_triplets:
|
395 |
+
# Triplet messages, as done by Dudzik and Velickovic (2022)
|
396 |
+
triplets = get_triplet_msgs(z, edge_fts, graph_fts, self.nb_triplet_fts)
|
397 |
+
|
398 |
+
o3 = hk.Linear(self.out_size)
|
399 |
+
tri_msgs = o3(jnp.max(triplets, axis=1)) # (B, N, N, H)
|
400 |
+
|
401 |
+
if self.activation is not None:
|
402 |
+
tri_msgs = self.activation(tri_msgs)
|
403 |
+
|
404 |
+
msgs = (
|
405 |
+
jnp.expand_dims(msg_1, axis=1) + jnp.expand_dims(msg_2, axis=2) +
|
406 |
+
msg_e + jnp.expand_dims(msg_g, axis=(1, 2)))
|
407 |
+
|
408 |
+
if self._msgs_mlp_sizes is not None:
|
409 |
+
msgs = hk.nets.MLP(self._msgs_mlp_sizes)(jax.nn.relu(msgs))
|
410 |
+
|
411 |
+
if self.mid_act is not None:
|
412 |
+
msgs = self.mid_act(msgs)
|
413 |
+
|
414 |
+
if self.reduction == jnp.mean:
|
415 |
+
msgs = jnp.sum(msgs * jnp.expand_dims(adj_mat, -1), axis=1)
|
416 |
+
msgs = msgs / jnp.sum(adj_mat, axis=-1, keepdims=True)
|
417 |
+
elif self.reduction == jnp.max:
|
418 |
+
maxarg = jnp.where(jnp.expand_dims(adj_mat, -1),
|
419 |
+
msgs,
|
420 |
+
-BIG_NUMBER)
|
421 |
+
msgs = jnp.max(maxarg, axis=1)
|
422 |
+
else:
|
423 |
+
msgs = self.reduction(msgs * jnp.expand_dims(adj_mat, -1), axis=1)
|
424 |
+
|
425 |
+
h_1 = o1(z)
|
426 |
+
h_2 = o2(msgs)
|
427 |
+
|
428 |
+
ret = h_1 + h_2
|
429 |
+
|
430 |
+
if self.activation is not None:
|
431 |
+
ret = self.activation(ret)
|
432 |
+
|
433 |
+
if self.use_ln:
|
434 |
+
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
|
435 |
+
ret = ln(ret)
|
436 |
+
|
437 |
+
if self.gated:
|
438 |
+
gate1 = hk.Linear(self.out_size)
|
439 |
+
gate2 = hk.Linear(self.out_size)
|
440 |
+
gate3 = hk.Linear(self.out_size, b_init=hk.initializers.Constant(-3))
|
441 |
+
gate = jax.nn.sigmoid(gate3(jax.nn.relu(gate1(z) + gate2(msgs))))
|
442 |
+
ret = ret * gate + hidden * (1-gate)
|
443 |
+
|
444 |
+
return ret, tri_msgs # pytype: disable=bad-return-type # numpy-scalars
|
445 |
+
|
446 |
+
|
447 |
+
class DeepSets(PGN):
|
448 |
+
"""Deep Sets (Zaheer et al., NeurIPS 2017)."""
|
449 |
+
|
450 |
+
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
|
451 |
+
adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
|
452 |
+
assert adj_mat.ndim == 3
|
453 |
+
adj_mat = jnp.ones_like(adj_mat) * jnp.eye(adj_mat.shape[-1])
|
454 |
+
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
|
455 |
+
|
456 |
+
|
457 |
+
class MPNN(PGN):
|
458 |
+
"""Message-Passing Neural Network (Gilmer et al., ICML 2017)."""
|
459 |
+
|
460 |
+
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
|
461 |
+
adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
|
462 |
+
adj_mat = jnp.ones_like(adj_mat)
|
463 |
+
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
|
464 |
+
|
465 |
+
|
466 |
+
class PGNMask(PGN):
|
467 |
+
"""Masked Pointer Graph Networks (Veličković et al., NeurIPS 2020)."""
|
468 |
+
|
469 |
+
@property
|
470 |
+
def inf_bias(self):
|
471 |
+
return True
|
472 |
+
|
473 |
+
@property
|
474 |
+
def inf_bias_edge(self):
|
475 |
+
return True
|
476 |
+
|
477 |
+
|
478 |
+
class MemNetMasked(Processor):
|
479 |
+
"""Implementation of End-to-End Memory Networks.
|
480 |
+
|
481 |
+
Inspired by the description in https://arxiv.org/abs/1503.08895.
|
482 |
+
"""
|
483 |
+
|
484 |
+
def __init__(
|
485 |
+
self,
|
486 |
+
vocab_size: int,
|
487 |
+
sentence_size: int,
|
488 |
+
linear_output_size: int,
|
489 |
+
embedding_size: int = 16,
|
490 |
+
memory_size: Optional[int] = 128,
|
491 |
+
num_hops: int = 1,
|
492 |
+
nonlin: Callable[[Any], Any] = jax.nn.relu,
|
493 |
+
apply_embeddings: bool = True,
|
494 |
+
init_func: hk.initializers.Initializer = jnp.zeros,
|
495 |
+
use_ln: bool = False,
|
496 |
+
name: str = 'memnet') -> None:
|
497 |
+
"""Constructor.
|
498 |
+
|
499 |
+
Args:
|
500 |
+
vocab_size: the number of words in the dictionary (each story, query and
|
501 |
+
answer come contain symbols coming from this dictionary).
|
502 |
+
sentence_size: the dimensionality of each memory.
|
503 |
+
linear_output_size: the dimensionality of the output of the last layer
|
504 |
+
of the model.
|
505 |
+
embedding_size: the dimensionality of the latent space to where all
|
506 |
+
memories are projected.
|
507 |
+
memory_size: the number of memories provided.
|
508 |
+
num_hops: the number of layers in the model.
|
509 |
+
nonlin: non-linear transformation applied at the end of each layer.
|
510 |
+
apply_embeddings: flag whether to aply embeddings.
|
511 |
+
init_func: initialization function for the biases.
|
512 |
+
use_ln: whether to use layer normalisation in the model.
|
513 |
+
name: the name of the model.
|
514 |
+
"""
|
515 |
+
super().__init__(name=name)
|
516 |
+
self._vocab_size = vocab_size
|
517 |
+
self._embedding_size = embedding_size
|
518 |
+
self._sentence_size = sentence_size
|
519 |
+
self._memory_size = memory_size
|
520 |
+
self._linear_output_size = linear_output_size
|
521 |
+
self._num_hops = num_hops
|
522 |
+
self._nonlin = nonlin
|
523 |
+
self._apply_embeddings = apply_embeddings
|
524 |
+
self._init_func = init_func
|
525 |
+
self._use_ln = use_ln
|
526 |
+
# Encoding part: i.e. "I" of the paper.
|
527 |
+
self._encodings = _position_encoding(sentence_size, embedding_size)
|
528 |
+
|
529 |
+
def __call__( # pytype: disable=signature-mismatch # numpy-scalars
|
530 |
+
self,
|
531 |
+
node_fts: _Array,
|
532 |
+
edge_fts: _Array,
|
533 |
+
graph_fts: _Array,
|
534 |
+
adj_mat: _Array,
|
535 |
+
hidden: _Array,
|
536 |
+
**unused_kwargs,
|
537 |
+
) -> _Array:
|
538 |
+
"""MemNet inference step."""
|
539 |
+
|
540 |
+
del hidden
|
541 |
+
node_and_graph_fts = jnp.concatenate([node_fts, graph_fts[:, None]],
|
542 |
+
axis=1)
|
543 |
+
edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None],
|
544 |
+
((0, 0), (0, 1), (0, 1), (0, 0)))
|
545 |
+
nxt_hidden = jax.vmap(self._apply, (1), 1)(node_and_graph_fts,
|
546 |
+
edge_fts_padded)
|
547 |
+
|
548 |
+
# Broadcast hidden state corresponding to graph features across the nodes.
|
549 |
+
nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
|
550 |
+
return nxt_hidden, None # pytype: disable=bad-return-type # numpy-scalars
|
551 |
+
|
552 |
+
def _apply(self, queries: _Array, stories: _Array) -> _Array:
|
553 |
+
"""Apply Memory Network to the queries and stories.
|
554 |
+
|
555 |
+
Args:
|
556 |
+
queries: Tensor of shape [batch_size, sentence_size].
|
557 |
+
stories: Tensor of shape [batch_size, memory_size, sentence_size].
|
558 |
+
|
559 |
+
Returns:
|
560 |
+
Tensor of shape [batch_size, vocab_size].
|
561 |
+
"""
|
562 |
+
if self._apply_embeddings:
|
563 |
+
query_biases = hk.get_parameter(
|
564 |
+
'query_biases',
|
565 |
+
shape=[self._vocab_size - 1, self._embedding_size],
|
566 |
+
init=self._init_func)
|
567 |
+
stories_biases = hk.get_parameter(
|
568 |
+
'stories_biases',
|
569 |
+
shape=[self._vocab_size - 1, self._embedding_size],
|
570 |
+
init=self._init_func)
|
571 |
+
memory_biases = hk.get_parameter(
|
572 |
+
'memory_contents',
|
573 |
+
shape=[self._memory_size, self._embedding_size],
|
574 |
+
init=self._init_func)
|
575 |
+
output_biases = hk.get_parameter(
|
576 |
+
'output_biases',
|
577 |
+
shape=[self._vocab_size - 1, self._embedding_size],
|
578 |
+
init=self._init_func)
|
579 |
+
|
580 |
+
nil_word_slot = jnp.zeros([1, self._embedding_size])
|
581 |
+
|
582 |
+
# This is "A" in the paper.
|
583 |
+
if self._apply_embeddings:
|
584 |
+
stories_biases = jnp.concatenate([stories_biases, nil_word_slot], axis=0)
|
585 |
+
memory_embeddings = jnp.take(
|
586 |
+
stories_biases, stories.reshape([-1]).astype(jnp.int32),
|
587 |
+
axis=0).reshape(list(stories.shape) + [self._embedding_size])
|
588 |
+
memory_embeddings = jnp.pad(
|
589 |
+
memory_embeddings,
|
590 |
+
((0, 0), (0, self._memory_size - jnp.shape(memory_embeddings)[1]),
|
591 |
+
(0, 0), (0, 0)))
|
592 |
+
memory = jnp.sum(memory_embeddings * self._encodings, 2) + memory_biases
|
593 |
+
else:
|
594 |
+
memory = stories
|
595 |
+
|
596 |
+
# This is "B" in the paper. Also, when there are no queries (only
|
597 |
+
# sentences), then there these lines are substituted by
|
598 |
+
# query_embeddings = 0.1.
|
599 |
+
if self._apply_embeddings:
|
600 |
+
query_biases = jnp.concatenate([query_biases, nil_word_slot], axis=0)
|
601 |
+
query_embeddings = jnp.take(
|
602 |
+
query_biases, queries.reshape([-1]).astype(jnp.int32),
|
603 |
+
axis=0).reshape(list(queries.shape) + [self._embedding_size])
|
604 |
+
# This is "u" in the paper.
|
605 |
+
query_input_embedding = jnp.sum(query_embeddings * self._encodings, 1)
|
606 |
+
else:
|
607 |
+
query_input_embedding = queries
|
608 |
+
|
609 |
+
# This is "C" in the paper.
|
610 |
+
if self._apply_embeddings:
|
611 |
+
output_biases = jnp.concatenate([output_biases, nil_word_slot], axis=0)
|
612 |
+
output_embeddings = jnp.take(
|
613 |
+
output_biases, stories.reshape([-1]).astype(jnp.int32),
|
614 |
+
axis=0).reshape(list(stories.shape) + [self._embedding_size])
|
615 |
+
output_embeddings = jnp.pad(
|
616 |
+
output_embeddings,
|
617 |
+
((0, 0), (0, self._memory_size - jnp.shape(output_embeddings)[1]),
|
618 |
+
(0, 0), (0, 0)))
|
619 |
+
output = jnp.sum(output_embeddings * self._encodings, 2)
|
620 |
+
else:
|
621 |
+
output = stories
|
622 |
+
|
623 |
+
intermediate_linear = hk.Linear(self._embedding_size, with_bias=False)
|
624 |
+
|
625 |
+
# Output_linear is "H".
|
626 |
+
output_linear = hk.Linear(self._linear_output_size, with_bias=False)
|
627 |
+
|
628 |
+
for hop_number in range(self._num_hops):
|
629 |
+
query_input_embedding_transposed = jnp.transpose(
|
630 |
+
jnp.expand_dims(query_input_embedding, -1), [0, 2, 1])
|
631 |
+
|
632 |
+
# Calculate probabilities.
|
633 |
+
probs = jax.nn.softmax(
|
634 |
+
jnp.sum(memory * query_input_embedding_transposed, 2))
|
635 |
+
|
636 |
+
# Calculate output of the layer by multiplying by C.
|
637 |
+
transposed_probs = jnp.transpose(jnp.expand_dims(probs, -1), [0, 2, 1])
|
638 |
+
transposed_output_embeddings = jnp.transpose(output, [0, 2, 1])
|
639 |
+
|
640 |
+
# This is "o" in the paper.
|
641 |
+
layer_output = jnp.sum(transposed_output_embeddings * transposed_probs, 2)
|
642 |
+
|
643 |
+
# Finally the answer
|
644 |
+
if hop_number == self._num_hops - 1:
|
645 |
+
# Please note that in the TF version we apply the final linear layer
|
646 |
+
# in all hops and this results in shape mismatches.
|
647 |
+
output_layer = output_linear(query_input_embedding + layer_output)
|
648 |
+
else:
|
649 |
+
output_layer = intermediate_linear(query_input_embedding + layer_output)
|
650 |
+
|
651 |
+
query_input_embedding = output_layer
|
652 |
+
if self._nonlin:
|
653 |
+
output_layer = self._nonlin(output_layer)
|
654 |
+
|
655 |
+
# This linear here is "W".
|
656 |
+
ret = hk.Linear(self._vocab_size, with_bias=False)(output_layer)
|
657 |
+
|
658 |
+
if self._use_ln:
|
659 |
+
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)
|
660 |
+
ret = ln(ret)
|
661 |
+
|
662 |
+
return ret
|
663 |
+
|
664 |
+
|
665 |
+
class MemNetFull(MemNetMasked):
|
666 |
+
"""Memory Networks with full adjacency matrix."""
|
667 |
+
|
668 |
+
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array,
|
669 |
+
adj_mat: _Array, hidden: _Array, **unused_kwargs) -> _Array:
|
670 |
+
adj_mat = jnp.ones_like(adj_mat)
|
671 |
+
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden)
|
672 |
+
|
673 |
+
|
674 |
+
ProcessorFactory = Callable[[int], Processor]
|
675 |
+
|
676 |
+
|
677 |
+
def get_processor_factory(kind: str,
|
678 |
+
use_ln: bool,
|
679 |
+
nb_triplet_fts: int,
|
680 |
+
nb_heads: Optional[int] = None) -> ProcessorFactory:
|
681 |
+
"""Returns a processor factory.
|
682 |
+
|
683 |
+
Args:
|
684 |
+
kind: One of the available types of processor.
|
685 |
+
use_ln: Whether the processor passes the output through a layernorm layer.
|
686 |
+
nb_triplet_fts: How many triplet features to compute.
|
687 |
+
nb_heads: Number of attention heads for GAT processors.
|
688 |
+
Returns:
|
689 |
+
A callable that takes an `out_size` parameter (equal to the hidden
|
690 |
+
dimension of the network) and returns a processor instance.
|
691 |
+
"""
|
692 |
+
def _factory(out_size: int):
|
693 |
+
if kind == 'deepsets':
|
694 |
+
processor = DeepSets(
|
695 |
+
out_size=out_size,
|
696 |
+
msgs_mlp_sizes=[out_size, out_size],
|
697 |
+
use_ln=use_ln,
|
698 |
+
use_triplets=False,
|
699 |
+
nb_triplet_fts=0
|
700 |
+
)
|
701 |
+
elif kind == 'gat':
|
702 |
+
processor = GAT(
|
703 |
+
out_size=out_size,
|
704 |
+
nb_heads=nb_heads,
|
705 |
+
use_ln=use_ln,
|
706 |
+
)
|
707 |
+
elif kind == 'gat_full':
|
708 |
+
processor = GATFull(
|
709 |
+
out_size=out_size,
|
710 |
+
nb_heads=nb_heads,
|
711 |
+
use_ln=use_ln
|
712 |
+
)
|
713 |
+
elif kind == 'gatv2':
|
714 |
+
processor = GATv2(
|
715 |
+
out_size=out_size,
|
716 |
+
nb_heads=nb_heads,
|
717 |
+
use_ln=use_ln
|
718 |
+
)
|
719 |
+
elif kind == 'gatv2_full':
|
720 |
+
processor = GATv2Full(
|
721 |
+
out_size=out_size,
|
722 |
+
nb_heads=nb_heads,
|
723 |
+
use_ln=use_ln
|
724 |
+
)
|
725 |
+
elif kind == 'memnet_full':
|
726 |
+
processor = MemNetFull(
|
727 |
+
vocab_size=out_size,
|
728 |
+
sentence_size=out_size,
|
729 |
+
linear_output_size=out_size,
|
730 |
+
)
|
731 |
+
elif kind == 'memnet_masked':
|
732 |
+
processor = MemNetMasked(
|
733 |
+
vocab_size=out_size,
|
734 |
+
sentence_size=out_size,
|
735 |
+
linear_output_size=out_size,
|
736 |
+
)
|
737 |
+
elif kind == 'mpnn':
|
738 |
+
processor = MPNN(
|
739 |
+
out_size=out_size,
|
740 |
+
msgs_mlp_sizes=[out_size, out_size],
|
741 |
+
use_ln=use_ln,
|
742 |
+
use_triplets=False,
|
743 |
+
nb_triplet_fts=0,
|
744 |
+
)
|
745 |
+
elif kind == 'pgn':
|
746 |
+
processor = PGN(
|
747 |
+
out_size=out_size,
|
748 |
+
msgs_mlp_sizes=[out_size, out_size],
|
749 |
+
use_ln=use_ln,
|
750 |
+
use_triplets=False,
|
751 |
+
nb_triplet_fts=0,
|
752 |
+
)
|
753 |
+
elif kind == 'pgn_mask':
|
754 |
+
processor = PGNMask(
|
755 |
+
out_size=out_size,
|
756 |
+
msgs_mlp_sizes=[out_size, out_size],
|
757 |
+
use_ln=use_ln,
|
758 |
+
use_triplets=False,
|
759 |
+
nb_triplet_fts=0,
|
760 |
+
)
|
761 |
+
elif kind == 'triplet_mpnn':
|
762 |
+
processor = MPNN(
|
763 |
+
out_size=out_size,
|
764 |
+
msgs_mlp_sizes=[out_size, out_size],
|
765 |
+
use_ln=use_ln,
|
766 |
+
use_triplets=True,
|
767 |
+
nb_triplet_fts=nb_triplet_fts,
|
768 |
+
)
|
769 |
+
elif kind == 'triplet_pgn':
|
770 |
+
processor = PGN(
|
771 |
+
out_size=out_size,
|
772 |
+
msgs_mlp_sizes=[out_size, out_size],
|
773 |
+
use_ln=use_ln,
|
774 |
+
use_triplets=True,
|
775 |
+
nb_triplet_fts=nb_triplet_fts,
|
776 |
+
)
|
777 |
+
elif kind == 'triplet_pgn_mask':
|
778 |
+
processor = PGNMask(
|
779 |
+
out_size=out_size,
|
780 |
+
msgs_mlp_sizes=[out_size, out_size],
|
781 |
+
use_ln=use_ln,
|
782 |
+
use_triplets=True,
|
783 |
+
nb_triplet_fts=nb_triplet_fts,
|
784 |
+
)
|
785 |
+
elif kind == 'gpgn':
|
786 |
+
processor = PGN(
|
787 |
+
out_size=out_size,
|
788 |
+
msgs_mlp_sizes=[out_size, out_size],
|
789 |
+
use_ln=use_ln,
|
790 |
+
use_triplets=False,
|
791 |
+
nb_triplet_fts=nb_triplet_fts,
|
792 |
+
gated=True,
|
793 |
+
)
|
794 |
+
elif kind == 'gpgn_mask':
|
795 |
+
processor = PGNMask(
|
796 |
+
out_size=out_size,
|
797 |
+
msgs_mlp_sizes=[out_size, out_size],
|
798 |
+
use_ln=use_ln,
|
799 |
+
use_triplets=False,
|
800 |
+
nb_triplet_fts=nb_triplet_fts,
|
801 |
+
gated=True,
|
802 |
+
)
|
803 |
+
elif kind == 'gmpnn':
|
804 |
+
processor = MPNN(
|
805 |
+
out_size=out_size,
|
806 |
+
msgs_mlp_sizes=[out_size, out_size],
|
807 |
+
use_ln=use_ln,
|
808 |
+
use_triplets=False,
|
809 |
+
nb_triplet_fts=nb_triplet_fts,
|
810 |
+
gated=True,
|
811 |
+
)
|
812 |
+
elif kind == 'triplet_gpgn':
|
813 |
+
processor = PGN(
|
814 |
+
out_size=out_size,
|
815 |
+
msgs_mlp_sizes=[out_size, out_size],
|
816 |
+
use_ln=use_ln,
|
817 |
+
use_triplets=True,
|
818 |
+
nb_triplet_fts=nb_triplet_fts,
|
819 |
+
gated=True,
|
820 |
+
)
|
821 |
+
elif kind == 'triplet_gpgn_mask':
|
822 |
+
processor = PGNMask(
|
823 |
+
out_size=out_size,
|
824 |
+
msgs_mlp_sizes=[out_size, out_size],
|
825 |
+
use_ln=use_ln,
|
826 |
+
use_triplets=True,
|
827 |
+
nb_triplet_fts=nb_triplet_fts,
|
828 |
+
gated=True,
|
829 |
+
)
|
830 |
+
elif kind == 'triplet_gmpnn':
|
831 |
+
processor = MPNN(
|
832 |
+
out_size=out_size,
|
833 |
+
msgs_mlp_sizes=[out_size, out_size],
|
834 |
+
use_ln=use_ln,
|
835 |
+
use_triplets=True,
|
836 |
+
nb_triplet_fts=nb_triplet_fts,
|
837 |
+
gated=True,
|
838 |
+
)
|
839 |
+
else:
|
840 |
+
raise ValueError('Unexpected processor kind ' + kind)
|
841 |
+
|
842 |
+
return processor
|
843 |
+
|
844 |
+
return _factory
|
845 |
+
|
846 |
+
|
847 |
+
def _position_encoding(sentence_size: int, embedding_size: int) -> np.ndarray:
|
848 |
+
"""Position Encoding described in section 4.1 [1]."""
|
849 |
+
encoding = np.ones((embedding_size, sentence_size), dtype=np.float32)
|
850 |
+
ls = sentence_size + 1
|
851 |
+
le = embedding_size + 1
|
852 |
+
for i in range(1, le):
|
853 |
+
for j in range(1, ls):
|
854 |
+
encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2)
|
855 |
+
encoding = 1 + 4 * encoding / embedding_size / sentence_size
|
856 |
+
return np.transpose(encoding)
|
benchmarks/CLRS/env/processors_test.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Tests for processors.py."""
|
17 |
+
|
18 |
+
from absl.testing import absltest
|
19 |
+
import chex
|
20 |
+
from clrs._src import processors
|
21 |
+
import haiku as hk
|
22 |
+
import jax.numpy as jnp
|
23 |
+
|
24 |
+
|
25 |
+
class MemnetTest(absltest.TestCase):
|
26 |
+
|
27 |
+
def test_simple_run_and_check_shapes(self):
|
28 |
+
|
29 |
+
batch_size = 64
|
30 |
+
vocab_size = 177
|
31 |
+
embedding_size = 64
|
32 |
+
sentence_size = 11
|
33 |
+
memory_size = 320
|
34 |
+
linear_output_size = 128
|
35 |
+
num_hops = 2
|
36 |
+
use_ln = True
|
37 |
+
|
38 |
+
def forward_fn(queries, stories):
|
39 |
+
model = processors.MemNetFull(
|
40 |
+
vocab_size=vocab_size,
|
41 |
+
embedding_size=embedding_size,
|
42 |
+
sentence_size=sentence_size,
|
43 |
+
memory_size=memory_size,
|
44 |
+
linear_output_size=linear_output_size,
|
45 |
+
num_hops=num_hops,
|
46 |
+
use_ln=use_ln)
|
47 |
+
return model._apply(queries, stories)
|
48 |
+
|
49 |
+
forward = hk.transform(forward_fn)
|
50 |
+
|
51 |
+
queries = jnp.ones([batch_size, sentence_size], dtype=jnp.int32)
|
52 |
+
stories = jnp.ones([batch_size, memory_size, sentence_size],
|
53 |
+
dtype=jnp.int32)
|
54 |
+
|
55 |
+
key = hk.PRNGSequence(42)
|
56 |
+
params = forward.init(next(key), queries, stories)
|
57 |
+
|
58 |
+
model_output = forward.apply(params, None, queries, stories)
|
59 |
+
chex.assert_shape(model_output, [batch_size, vocab_size])
|
60 |
+
chex.assert_type(model_output, jnp.float32)
|
61 |
+
|
62 |
+
|
63 |
+
if __name__ == '__main__':
|
64 |
+
absltest.main()
|
benchmarks/CLRS/env/samplers.py
ADDED
@@ -0,0 +1,882 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Sampling utilities."""
|
17 |
+
|
18 |
+
import abc
|
19 |
+
import collections
|
20 |
+
import inspect
|
21 |
+
import types
|
22 |
+
|
23 |
+
from typing import Any, Callable, List, Optional, Tuple
|
24 |
+
from absl import logging
|
25 |
+
|
26 |
+
from clrs._src import algorithms
|
27 |
+
from clrs._src import probing
|
28 |
+
from clrs._src import specs
|
29 |
+
import jax
|
30 |
+
import numpy as np
|
31 |
+
|
32 |
+
|
33 |
+
_Array = np.ndarray
|
34 |
+
_DataPoint = probing.DataPoint
|
35 |
+
Trajectory = List[_DataPoint]
|
36 |
+
Trajectories = List[Trajectory]
|
37 |
+
|
38 |
+
|
39 |
+
Algorithm = Callable[..., Any]
|
40 |
+
Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
|
41 |
+
FeaturesChunked = collections.namedtuple(
|
42 |
+
'Features', ['inputs', 'hints', 'is_first', 'is_last'])
|
43 |
+
Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])
|
44 |
+
|
45 |
+
# CLRS-30 baseline spec.
|
46 |
+
CLRS30 = types.MappingProxyType({
|
47 |
+
'train': {
|
48 |
+
'num_samples': 1000,
|
49 |
+
'length': 16,
|
50 |
+
'seed': 1,
|
51 |
+
},
|
52 |
+
'val': {
|
53 |
+
'num_samples': 32,
|
54 |
+
'length': 16,
|
55 |
+
'seed': 2,
|
56 |
+
},
|
57 |
+
'test': {
|
58 |
+
'num_samples': 32,
|
59 |
+
'length': 64,
|
60 |
+
'seed': 3,
|
61 |
+
},
|
62 |
+
})
|
63 |
+
|
64 |
+
|
65 |
+
class Sampler(abc.ABC):
|
66 |
+
"""Sampler abstract base class."""
|
67 |
+
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
algorithm: Algorithm,
|
71 |
+
spec: specs.Spec,
|
72 |
+
num_samples: int,
|
73 |
+
*args,
|
74 |
+
seed: Optional[int] = None,
|
75 |
+
**kwargs,
|
76 |
+
):
|
77 |
+
"""Initializes a `Sampler`.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
algorithm: The algorithm to sample from
|
81 |
+
spec: The algorithm spec.
|
82 |
+
num_samples: Number of algorithm unrolls to sample. If positive, all the
|
83 |
+
samples will be generated in the constructor, and at each call
|
84 |
+
of the `next` method a batch will be randomly selected among them.
|
85 |
+
If -1, samples are generated on the fly with each call to `next`.
|
86 |
+
*args: Algorithm args.
|
87 |
+
seed: RNG seed.
|
88 |
+
**kwargs: Algorithm kwargs.
|
89 |
+
"""
|
90 |
+
|
91 |
+
# Use `RandomState` to ensure deterministic sampling across Numpy versions.
|
92 |
+
self._rng = np.random.RandomState(seed)
|
93 |
+
self._spec = spec
|
94 |
+
self._num_samples = num_samples
|
95 |
+
self._algorithm = algorithm
|
96 |
+
self._args = args
|
97 |
+
self._kwargs = kwargs
|
98 |
+
|
99 |
+
if num_samples < 0:
|
100 |
+
logging.warning('Sampling dataset on-the-fly, unlimited samples.')
|
101 |
+
# Just get an initial estimate of max hint length
|
102 |
+
self.max_steps = -1
|
103 |
+
for _ in range(1000):
|
104 |
+
data = self._sample_data(*args, **kwargs)
|
105 |
+
_, probes = algorithm(*data)
|
106 |
+
_, _, hint = probing.split_stages(probes, spec)
|
107 |
+
for dp in hint:
|
108 |
+
assert dp.data.shape[1] == 1 # batching axis
|
109 |
+
if dp.data.shape[0] > self.max_steps:
|
110 |
+
self.max_steps = dp.data.shape[0]
|
111 |
+
else:
|
112 |
+
logging.info('Creating a dataset with %i samples.', num_samples)
|
113 |
+
(self._inputs, self._outputs, self._hints,
|
114 |
+
self._lengths) = self._make_batch(num_samples, spec, 0, algorithm, *args,
|
115 |
+
**kwargs)
|
116 |
+
|
117 |
+
def _make_batch(self, num_samples: int, spec: specs.Spec, min_length: int,
|
118 |
+
algorithm: Algorithm, *args, **kwargs):
|
119 |
+
"""Generate a batch of data."""
|
120 |
+
inputs = []
|
121 |
+
outputs = []
|
122 |
+
hints = []
|
123 |
+
|
124 |
+
for _ in range(num_samples):
|
125 |
+
data = self._sample_data(*args, **kwargs)
|
126 |
+
_, probes = algorithm(*data)
|
127 |
+
inp, outp, hint = probing.split_stages(probes, spec)
|
128 |
+
inputs.append(inp)
|
129 |
+
outputs.append(outp)
|
130 |
+
hints.append(hint)
|
131 |
+
if len(hints) % 1000 == 0:
|
132 |
+
logging.info('%i samples created', len(hints))
|
133 |
+
|
134 |
+
# Batch and pad trajectories to max(T).
|
135 |
+
inputs = _batch_io(inputs)
|
136 |
+
outputs = _batch_io(outputs)
|
137 |
+
hints, lengths = _batch_hints(hints, min_length)
|
138 |
+
return inputs, outputs, hints, lengths
|
139 |
+
|
140 |
+
def next(self, batch_size: Optional[int] = None) -> Feedback:
|
141 |
+
"""Subsamples trajectories from the pre-generated dataset.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
batch_size: Optional batch size. If `None`, returns entire dataset.
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
Subsampled trajectories.
|
148 |
+
"""
|
149 |
+
if batch_size:
|
150 |
+
if self._num_samples < 0: # generate on the fly
|
151 |
+
inputs, outputs, hints, lengths = self._make_batch(
|
152 |
+
batch_size, self._spec, self.max_steps,
|
153 |
+
self._algorithm, *self._args, **self._kwargs)
|
154 |
+
if hints[0].data.shape[0] > self.max_steps:
|
155 |
+
logging.warning('Increasing hint lengh from %i to %i',
|
156 |
+
self.max_steps, hints[0].data.shape[0])
|
157 |
+
self.max_steps = hints[0].data.shape[0]
|
158 |
+
else:
|
159 |
+
if batch_size > self._num_samples:
|
160 |
+
raise ValueError(
|
161 |
+
f'Batch size {batch_size} > dataset size {self._num_samples}.')
|
162 |
+
|
163 |
+
# Returns a fixed-size random batch.
|
164 |
+
indices = self._rng.choice(self._num_samples, (batch_size,),
|
165 |
+
replace=True)
|
166 |
+
inputs = _subsample_data(self._inputs, indices, axis=0)
|
167 |
+
outputs = _subsample_data(self._outputs, indices, axis=0)
|
168 |
+
hints = _subsample_data(self._hints, indices, axis=1)
|
169 |
+
lengths = self._lengths[indices]
|
170 |
+
|
171 |
+
else:
|
172 |
+
# Returns the full dataset.
|
173 |
+
assert self._num_samples >= 0
|
174 |
+
inputs = self._inputs
|
175 |
+
hints = self._hints
|
176 |
+
lengths = self._lengths
|
177 |
+
outputs = self._outputs
|
178 |
+
|
179 |
+
return Feedback(Features(inputs, hints, lengths), outputs)
|
180 |
+
|
181 |
+
@abc.abstractmethod
|
182 |
+
def _sample_data(self, length: int, *args, **kwargs) -> List[_Array]:
|
183 |
+
pass
|
184 |
+
|
185 |
+
def _random_sequence(self, length, low=0.0, high=1.0):
|
186 |
+
"""Random sequence."""
|
187 |
+
return self._rng.uniform(low=low, high=high, size=(length,))
|
188 |
+
|
189 |
+
def _random_string(self, length, chars=4):
|
190 |
+
"""Random string."""
|
191 |
+
return self._rng.randint(0, high=chars, size=(length,))
|
192 |
+
|
193 |
+
def _random_er_graph(self, nb_nodes, p=0.5, directed=False, acyclic=False,
|
194 |
+
weighted=False, low=0.0, high=1.0):
|
195 |
+
"""Random Erdos-Renyi graph."""
|
196 |
+
|
197 |
+
mat = self._rng.binomial(1, p, size=(nb_nodes, nb_nodes))
|
198 |
+
if not directed:
|
199 |
+
mat *= np.transpose(mat)
|
200 |
+
elif acyclic:
|
201 |
+
mat = np.triu(mat, k=1)
|
202 |
+
p = self._rng.permutation(nb_nodes) # To allow nontrivial solutions
|
203 |
+
mat = mat[p, :][:, p]
|
204 |
+
if weighted:
|
205 |
+
weights = self._rng.uniform(low=low, high=high, size=(nb_nodes, nb_nodes))
|
206 |
+
if not directed:
|
207 |
+
weights *= np.transpose(weights)
|
208 |
+
weights = np.sqrt(weights + 1e-3) # Add epsilon to protect underflow
|
209 |
+
mat = mat.astype(float) * weights
|
210 |
+
return mat
|
211 |
+
|
212 |
+
def _random_community_graph(self, nb_nodes, k=4, p=0.5, eps=0.01,
|
213 |
+
directed=False, acyclic=False, weighted=False,
|
214 |
+
low=0.0, high=1.0):
|
215 |
+
"""Random perturbed k-community graph."""
|
216 |
+
mat = np.zeros((nb_nodes, nb_nodes))
|
217 |
+
if k > nb_nodes:
|
218 |
+
raise ValueError(f'Cannot generate graph of too many ({k}) communities.')
|
219 |
+
los, his = [], []
|
220 |
+
lo = 0
|
221 |
+
for i in range(k):
|
222 |
+
if i == k - 1:
|
223 |
+
hi = nb_nodes
|
224 |
+
else:
|
225 |
+
hi = lo + nb_nodes // k
|
226 |
+
mat[lo:hi, lo:hi] = self._random_er_graph(
|
227 |
+
hi - lo, p=p, directed=directed,
|
228 |
+
acyclic=acyclic, weighted=weighted,
|
229 |
+
low=low, high=high)
|
230 |
+
los.append(lo)
|
231 |
+
his.append(hi)
|
232 |
+
lo = hi
|
233 |
+
toggle = self._random_er_graph(nb_nodes, p=eps, directed=directed,
|
234 |
+
acyclic=acyclic, weighted=weighted,
|
235 |
+
low=low, high=high)
|
236 |
+
|
237 |
+
# Prohibit closing new cycles
|
238 |
+
for i in range(k):
|
239 |
+
for j in range(i):
|
240 |
+
toggle[los[i]:his[i], los[j]:his[j]] *= 0
|
241 |
+
|
242 |
+
mat = np.where(toggle > 0.0, (1.0 - (mat > 0.0)) * toggle, mat)
|
243 |
+
p = self._rng.permutation(nb_nodes) # To allow nontrivial solutions
|
244 |
+
mat = mat[p, :][:, p]
|
245 |
+
return mat
|
246 |
+
|
247 |
+
def _random_bipartite_graph(self, n, m, p=0.25):
|
248 |
+
"""Random bipartite graph-based flow network."""
|
249 |
+
nb_nodes = n + m + 2
|
250 |
+
s = 0
|
251 |
+
t = n + m + 1
|
252 |
+
mat = np.zeros((nb_nodes, nb_nodes))
|
253 |
+
mat[s, 1:n+1] = 1.0 # supersource
|
254 |
+
mat[n+1:n+m+1, t] = 1.0 # supersink
|
255 |
+
mat[1:n+1, n+1:n+m+1] = self._rng.binomial(1, p, size=(n, m))
|
256 |
+
return mat
|
257 |
+
|
258 |
+
|
259 |
+
def build_sampler(
|
260 |
+
name: str,
|
261 |
+
num_samples: int,
|
262 |
+
*args,
|
263 |
+
seed: Optional[int] = None,
|
264 |
+
**kwargs,
|
265 |
+
) -> Tuple[Sampler, specs.Spec]:
|
266 |
+
"""Builds a sampler. See `Sampler` documentation."""
|
267 |
+
|
268 |
+
if name not in specs.SPECS or name not in SAMPLERS:
|
269 |
+
raise NotImplementedError(f'No implementation of algorithm {name}.')
|
270 |
+
spec = specs.SPECS[name]
|
271 |
+
algorithm = getattr(algorithms, name)
|
272 |
+
sampler_class = SAMPLERS[name]
|
273 |
+
# Ignore kwargs not accepted by the sampler.
|
274 |
+
sampler_args = inspect.signature(sampler_class._sample_data).parameters # pylint:disable=protected-access
|
275 |
+
clean_kwargs = {k: kwargs[k] for k in kwargs if k in sampler_args}
|
276 |
+
if set(clean_kwargs) != set(kwargs):
|
277 |
+
logging.warning('Ignoring kwargs %s when building sampler class %s',
|
278 |
+
set(kwargs).difference(clean_kwargs), sampler_class)
|
279 |
+
sampler = sampler_class(algorithm, spec, num_samples, seed=seed,
|
280 |
+
*args, **clean_kwargs)
|
281 |
+
return sampler, spec
|
282 |
+
|
283 |
+
|
284 |
+
class SortingSampler(Sampler):
|
285 |
+
"""Sorting sampler. Generates a random sequence of U[0, 1]."""
|
286 |
+
|
287 |
+
def _sample_data(
|
288 |
+
self,
|
289 |
+
length: int,
|
290 |
+
low: float = 0.,
|
291 |
+
high: float = 1.,
|
292 |
+
):
|
293 |
+
arr = self._random_sequence(length=length, low=low, high=high)
|
294 |
+
return [arr]
|
295 |
+
|
296 |
+
|
297 |
+
class SearchSampler(Sampler):
|
298 |
+
"""Search sampler. Generates a random sequence and target (of U[0, 1])."""
|
299 |
+
|
300 |
+
def _sample_data(
|
301 |
+
self,
|
302 |
+
length: int,
|
303 |
+
low: float = 0.,
|
304 |
+
high: float = 1.,
|
305 |
+
):
|
306 |
+
arr = self._random_sequence(length=length, low=low, high=high)
|
307 |
+
arr.sort()
|
308 |
+
x = self._rng.uniform(low=low, high=high)
|
309 |
+
return [x, arr]
|
310 |
+
|
311 |
+
|
312 |
+
class MaxSubarraySampler(Sampler):
|
313 |
+
"""Maximum subarray sampler. Generates a random sequence of U[-1, 1]."""
|
314 |
+
|
315 |
+
def _sample_data(
|
316 |
+
self,
|
317 |
+
length: int,
|
318 |
+
low: float = -1.,
|
319 |
+
high: float = 1.,
|
320 |
+
):
|
321 |
+
arr = self._random_sequence(length=length, low=low, high=high)
|
322 |
+
return [arr]
|
323 |
+
|
324 |
+
|
325 |
+
class LCSSampler(Sampler):
|
326 |
+
"""Longest Common Subsequence sampler. Generates two random ATCG strings."""
|
327 |
+
|
328 |
+
def _sample_data(
|
329 |
+
self,
|
330 |
+
length: int,
|
331 |
+
length_2: Optional[int] = None,
|
332 |
+
chars: int = 4,
|
333 |
+
):
|
334 |
+
if length_2 is None:
|
335 |
+
# Assume provided length is total length.
|
336 |
+
length_2 = length // 2
|
337 |
+
length -= length_2
|
338 |
+
a = self._random_string(length=length, chars=chars)
|
339 |
+
b = self._random_string(length=length_2, chars=chars)
|
340 |
+
return [a, b]
|
341 |
+
|
342 |
+
|
343 |
+
class OptimalBSTSampler(Sampler):
|
344 |
+
"""Optimal BST sampler. Samples array of probabilities, splits it into two."""
|
345 |
+
|
346 |
+
def _sample_data(
|
347 |
+
self,
|
348 |
+
length: int,
|
349 |
+
):
|
350 |
+
tot_length = length + (length + 1)
|
351 |
+
arr = self._random_sequence(length=tot_length, low=0.0, high=1.0)
|
352 |
+
arr /= np.sum(arr)
|
353 |
+
p = arr[:length]
|
354 |
+
q = arr[length:]
|
355 |
+
return [p, q]
|
356 |
+
|
357 |
+
|
358 |
+
class ActivitySampler(Sampler):
|
359 |
+
"""Activity sampler. Samples start and finish times from U[0, 1]."""
|
360 |
+
|
361 |
+
def _sample_data(
|
362 |
+
self,
|
363 |
+
length: int,
|
364 |
+
low: float = 0.,
|
365 |
+
high: float = 1.,
|
366 |
+
):
|
367 |
+
arr_1 = self._random_sequence(length=length, low=low, high=high)
|
368 |
+
arr_2 = self._random_sequence(length=length, low=low, high=high)
|
369 |
+
return [np.minimum(arr_1, arr_2), np.maximum(arr_1, arr_2)]
|
370 |
+
|
371 |
+
|
372 |
+
class TaskSampler(Sampler):
|
373 |
+
"""Task sampler. Samples deadlines (integers) and values (U[0, 1])."""
|
374 |
+
|
375 |
+
def _sample_data(
|
376 |
+
self,
|
377 |
+
length: int,
|
378 |
+
max_deadline: Optional[int] = None,
|
379 |
+
low: float = 0.,
|
380 |
+
high: float = 1.,
|
381 |
+
):
|
382 |
+
if max_deadline is None:
|
383 |
+
max_deadline = length
|
384 |
+
d = self._random_string(length=length, chars=max_deadline) + 1
|
385 |
+
w = self._random_sequence(length=length, low=low, high=high)
|
386 |
+
return [d, w]
|
387 |
+
|
388 |
+
|
389 |
+
class DfsSampler(Sampler):
|
390 |
+
"""DFS sampler."""
|
391 |
+
|
392 |
+
def _sample_data(
|
393 |
+
self,
|
394 |
+
length: int,
|
395 |
+
p: Tuple[float, ...] = (0.5,),
|
396 |
+
):
|
397 |
+
graph = self._random_er_graph(
|
398 |
+
nb_nodes=length, p=self._rng.choice(p),
|
399 |
+
directed=True, acyclic=False, weighted=False)
|
400 |
+
return [graph]
|
401 |
+
|
402 |
+
|
403 |
+
class BfsSampler(Sampler):
|
404 |
+
"""BFS sampler."""
|
405 |
+
|
406 |
+
def _sample_data(
|
407 |
+
self,
|
408 |
+
length: int,
|
409 |
+
p: Tuple[float, ...] = (0.5,),
|
410 |
+
):
|
411 |
+
graph = self._random_er_graph(
|
412 |
+
nb_nodes=length, p=self._rng.choice(p),
|
413 |
+
directed=False, acyclic=False, weighted=False)
|
414 |
+
source_node = self._rng.choice(length)
|
415 |
+
return [graph, source_node]
|
416 |
+
|
417 |
+
|
418 |
+
class TopoSampler(Sampler):
|
419 |
+
"""Topological Sorting sampler."""
|
420 |
+
|
421 |
+
def _sample_data(
|
422 |
+
self,
|
423 |
+
length: int,
|
424 |
+
p: Tuple[float, ...] = (0.5,),
|
425 |
+
):
|
426 |
+
graph = self._random_er_graph(
|
427 |
+
nb_nodes=length, p=self._rng.choice(p),
|
428 |
+
directed=True, acyclic=True, weighted=False)
|
429 |
+
return [graph]
|
430 |
+
|
431 |
+
|
432 |
+
class ArticulationSampler(Sampler):
|
433 |
+
"""Articulation Point sampler."""
|
434 |
+
|
435 |
+
def _sample_data(
|
436 |
+
self,
|
437 |
+
length: int,
|
438 |
+
p: Tuple[float, ...] = (0.2,),
|
439 |
+
):
|
440 |
+
graph = self._random_er_graph(
|
441 |
+
nb_nodes=length, p=self._rng.choice(p), directed=False,
|
442 |
+
acyclic=False, weighted=False)
|
443 |
+
return [graph]
|
444 |
+
|
445 |
+
|
446 |
+
class MSTSampler(Sampler):
|
447 |
+
"""MST sampler for Kruskal's algorithm."""
|
448 |
+
|
449 |
+
def _sample_data(
|
450 |
+
self,
|
451 |
+
length: int,
|
452 |
+
p: Tuple[float, ...] = (0.2,), # lower p to account for class imbalance
|
453 |
+
low: float = 0.,
|
454 |
+
high: float = 1.,
|
455 |
+
):
|
456 |
+
graph = self._random_er_graph(
|
457 |
+
nb_nodes=length,
|
458 |
+
p=self._rng.choice(p),
|
459 |
+
directed=False,
|
460 |
+
acyclic=False,
|
461 |
+
weighted=True,
|
462 |
+
low=low,
|
463 |
+
high=high)
|
464 |
+
return [graph]
|
465 |
+
|
466 |
+
|
467 |
+
class BellmanFordSampler(Sampler):
|
468 |
+
"""Bellman-Ford sampler."""
|
469 |
+
|
470 |
+
def _sample_data(
|
471 |
+
self,
|
472 |
+
length: int,
|
473 |
+
p: Tuple[float, ...] = (0.5,),
|
474 |
+
low: float = 0.,
|
475 |
+
high: float = 1.,
|
476 |
+
):
|
477 |
+
graph = self._random_er_graph(
|
478 |
+
nb_nodes=length,
|
479 |
+
p=self._rng.choice(p),
|
480 |
+
directed=False,
|
481 |
+
acyclic=False,
|
482 |
+
weighted=True,
|
483 |
+
low=low,
|
484 |
+
high=high)
|
485 |
+
source_node = self._rng.choice(length)
|
486 |
+
return [graph, source_node]
|
487 |
+
|
488 |
+
|
489 |
+
class DAGPathSampler(Sampler):
|
490 |
+
"""Sampler for DAG shortest paths."""
|
491 |
+
|
492 |
+
def _sample_data(
|
493 |
+
self,
|
494 |
+
length: int,
|
495 |
+
p: Tuple[float, ...] = (0.5,),
|
496 |
+
low: float = 0.,
|
497 |
+
high: float = 1.,
|
498 |
+
):
|
499 |
+
graph = self._random_er_graph(
|
500 |
+
nb_nodes=length,
|
501 |
+
p=self._rng.choice(p),
|
502 |
+
directed=True,
|
503 |
+
acyclic=True,
|
504 |
+
weighted=True,
|
505 |
+
low=low,
|
506 |
+
high=high)
|
507 |
+
source_node = self._rng.choice(length)
|
508 |
+
return [graph, source_node]
|
509 |
+
|
510 |
+
|
511 |
+
class FloydWarshallSampler(Sampler):
|
512 |
+
"""Sampler for all-pairs shortest paths."""
|
513 |
+
|
514 |
+
def _sample_data(
|
515 |
+
self,
|
516 |
+
length: int,
|
517 |
+
p: Tuple[float, ...] = (0.5,),
|
518 |
+
low: float = 0.,
|
519 |
+
high: float = 1.,
|
520 |
+
):
|
521 |
+
graph = self._random_er_graph(
|
522 |
+
nb_nodes=length,
|
523 |
+
p=self._rng.choice(p),
|
524 |
+
directed=False,
|
525 |
+
acyclic=False,
|
526 |
+
weighted=True,
|
527 |
+
low=low,
|
528 |
+
high=high)
|
529 |
+
return [graph]
|
530 |
+
|
531 |
+
|
532 |
+
class SccSampler(Sampler):
|
533 |
+
"""Sampler for strongly connected component (SCC) tasks."""
|
534 |
+
|
535 |
+
def _sample_data(
|
536 |
+
self,
|
537 |
+
length: int,
|
538 |
+
k: int = 4,
|
539 |
+
p: Tuple[float, ...] = (0.5,),
|
540 |
+
eps: float = 0.01,
|
541 |
+
):
|
542 |
+
graph = self._random_community_graph(
|
543 |
+
nb_nodes=length, k=k, p=self._rng.choice(p), eps=eps,
|
544 |
+
directed=True, acyclic=False, weighted=False)
|
545 |
+
return [graph]
|
546 |
+
|
547 |
+
|
548 |
+
class BipartiteSampler(Sampler):
|
549 |
+
"""Sampler for bipartite matching-based flow networks."""
|
550 |
+
|
551 |
+
def _sample_data(
|
552 |
+
self,
|
553 |
+
length: int,
|
554 |
+
length_2: Optional[int] = None,
|
555 |
+
p: Tuple[float, ...] = (0.3,),
|
556 |
+
):
|
557 |
+
if length_2 is None:
|
558 |
+
# Assume provided length is total length.
|
559 |
+
length_2 = length // 2
|
560 |
+
length -= length_2
|
561 |
+
graph = self._random_bipartite_graph(n=length, m=length_2,
|
562 |
+
p=self._rng.choice(p))
|
563 |
+
return [graph, length, length_2, 0, length + length_2 + 1]
|
564 |
+
|
565 |
+
|
566 |
+
class MatcherSampler(Sampler):
|
567 |
+
"""String matching sampler; embeds needle in a random haystack."""
|
568 |
+
|
569 |
+
def _sample_data(
|
570 |
+
self,
|
571 |
+
length: int, # length of haystack + needle, i.e., total number of nodes
|
572 |
+
length_needle: Optional[int] = None,
|
573 |
+
chars: int = 4,
|
574 |
+
):
|
575 |
+
if length_needle is None:
|
576 |
+
if length < 5:
|
577 |
+
length_needle = 1
|
578 |
+
else:
|
579 |
+
length_needle = length // 5
|
580 |
+
elif length_needle < 0: # randomize needle length
|
581 |
+
length_needle = self._rng.randint(1, high=1 - length_needle)
|
582 |
+
length_haystack = length - length_needle
|
583 |
+
needle = self._random_string(length=length_needle, chars=chars)
|
584 |
+
haystack = self._random_string(length=length_haystack, chars=chars)
|
585 |
+
embed_pos = self._rng.choice(length_haystack - length_needle)
|
586 |
+
haystack[embed_pos:embed_pos + length_needle] = needle
|
587 |
+
return [haystack, needle]
|
588 |
+
|
589 |
+
|
590 |
+
class SegmentsSampler(Sampler):
|
591 |
+
"""Two-segment sampler of points from (U[0, 1], U[0, 1])."""
|
592 |
+
|
593 |
+
def _sample_data(self, length: int, low: float = 0., high: float = 1.):
|
594 |
+
del length # There are exactly four endpoints.
|
595 |
+
|
596 |
+
# Quick CCW check (ignoring collinearity) for rejection sampling
|
597 |
+
def ccw(x_a, y_a, x_b, y_b, x_c, y_c):
|
598 |
+
return (y_c - y_a) * (x_b - x_a) > (y_b - y_a) * (x_c - x_a)
|
599 |
+
def intersect(xs, ys):
|
600 |
+
return ccw(xs[0], ys[0], xs[2], ys[2], xs[3], ys[3]) != ccw(
|
601 |
+
xs[1], ys[1], xs[2], ys[2], xs[3], ys[3]) and ccw(
|
602 |
+
xs[0], ys[0], xs[1], ys[1], xs[2], ys[2]) != ccw(
|
603 |
+
xs[0], ys[0], xs[1], ys[1], xs[3], ys[3])
|
604 |
+
|
605 |
+
# Decide (with uniform probability) should this sample intersect
|
606 |
+
coin_flip = self._rng.binomial(1, 0.5)
|
607 |
+
|
608 |
+
xs = self._random_sequence(length=4, low=low, high=high)
|
609 |
+
ys = self._random_sequence(length=4, low=low, high=high)
|
610 |
+
|
611 |
+
while intersect(xs, ys) != coin_flip:
|
612 |
+
xs = self._random_sequence(length=4, low=low, high=high)
|
613 |
+
ys = self._random_sequence(length=4, low=low, high=high)
|
614 |
+
|
615 |
+
return [xs, ys]
|
616 |
+
|
617 |
+
|
618 |
+
class ConvexHullSampler(Sampler):
|
619 |
+
"""Convex hull sampler of points over a disk of radius r."""
|
620 |
+
|
621 |
+
def _sample_data(self, length: int, origin_x: float = 0.,
|
622 |
+
origin_y: float = 0., radius: float = 2.):
|
623 |
+
|
624 |
+
thetas = self._random_sequence(length=length, low=0.0, high=2.0 * np.pi)
|
625 |
+
rs = radius * np.sqrt(
|
626 |
+
self._random_sequence(length=length, low=0.0, high=1.0))
|
627 |
+
|
628 |
+
xs = rs * np.cos(thetas) + origin_x
|
629 |
+
ys = rs * np.sin(thetas) + origin_y
|
630 |
+
|
631 |
+
return [xs, ys]
|
632 |
+
|
633 |
+
|
634 |
+
SAMPLERS = {
|
635 |
+
'insertion_sort': SortingSampler,
|
636 |
+
'bubble_sort': SortingSampler,
|
637 |
+
'heapsort': SortingSampler,
|
638 |
+
'quicksort': SortingSampler,
|
639 |
+
'quickselect': SortingSampler,
|
640 |
+
'minimum': SortingSampler,
|
641 |
+
'binary_search': SearchSampler,
|
642 |
+
'find_maximum_subarray': MaxSubarraySampler,
|
643 |
+
'find_maximum_subarray_kadane': MaxSubarraySampler,
|
644 |
+
'matrix_chain_order': SortingSampler,
|
645 |
+
'lcs_length': LCSSampler,
|
646 |
+
'optimal_bst': OptimalBSTSampler,
|
647 |
+
'activity_selector': ActivitySampler,
|
648 |
+
'task_scheduling': TaskSampler,
|
649 |
+
'dfs': DfsSampler,
|
650 |
+
'topological_sort': TopoSampler,
|
651 |
+
'strongly_connected_components': SccSampler,
|
652 |
+
'articulation_points': ArticulationSampler,
|
653 |
+
'bridges': ArticulationSampler,
|
654 |
+
'bfs': BfsSampler,
|
655 |
+
'mst_kruskal': MSTSampler,
|
656 |
+
'mst_prim': BellmanFordSampler,
|
657 |
+
'bellman_ford': BellmanFordSampler,
|
658 |
+
'dag_shortest_paths': DAGPathSampler,
|
659 |
+
'dijkstra': BellmanFordSampler,
|
660 |
+
'floyd_warshall': FloydWarshallSampler,
|
661 |
+
'bipartite_matching': BipartiteSampler,
|
662 |
+
'naive_string_matcher': MatcherSampler,
|
663 |
+
'kmp_matcher': MatcherSampler,
|
664 |
+
'segments_intersect': SegmentsSampler,
|
665 |
+
'graham_scan': ConvexHullSampler,
|
666 |
+
'jarvis_march': ConvexHullSampler,
|
667 |
+
}
|
668 |
+
|
669 |
+
|
670 |
+
def _batch_io(traj_io: Trajectories) -> Trajectory:
|
671 |
+
"""Batches a trajectory of input/output samples along the time axis per probe.
|
672 |
+
|
673 |
+
Args:
|
674 |
+
traj_io: An i/o trajectory of `DataPoint`s indexed by time then probe.
|
675 |
+
|
676 |
+
Returns:
|
677 |
+
A |num probes| list of `DataPoint`s with the time axis stacked into `data`.
|
678 |
+
"""
|
679 |
+
|
680 |
+
assert traj_io # non-empty
|
681 |
+
for sample_io in traj_io:
|
682 |
+
for i, dp in enumerate(sample_io):
|
683 |
+
assert dp.data.shape[0] == 1 # batching axis
|
684 |
+
assert traj_io[0][i].name == dp.name
|
685 |
+
|
686 |
+
return jax.tree_util.tree_map(lambda *x: np.concatenate(x), *traj_io)
|
687 |
+
|
688 |
+
|
689 |
+
def _batch_hints(
|
690 |
+
traj_hints: Trajectories, min_steps: int) -> Tuple[Trajectory, List[int]]:
|
691 |
+
"""Batches a trajectory of hints samples along the time axis per probe.
|
692 |
+
|
693 |
+
Unlike i/o, hints have a variable-length time dimension. Before batching, each
|
694 |
+
trajectory is padded to the maximum trajectory length.
|
695 |
+
|
696 |
+
Args:
|
697 |
+
traj_hints: A hint trajectory of `DataPoints`s indexed by time then probe
|
698 |
+
min_steps: Hints will be padded at least to this length - if any hint is
|
699 |
+
longer than this, the greater length will be used.
|
700 |
+
|
701 |
+
Returns:
|
702 |
+
A |num probes| list of `DataPoint`s with the time axis stacked into `data`,
|
703 |
+
and a |sample| list containing the length of each trajectory.
|
704 |
+
"""
|
705 |
+
|
706 |
+
max_steps = min_steps
|
707 |
+
assert traj_hints # non-empty
|
708 |
+
for sample_hint in traj_hints:
|
709 |
+
for dp in sample_hint:
|
710 |
+
assert dp.data.shape[1] == 1 # batching axis
|
711 |
+
if dp.data.shape[0] > max_steps:
|
712 |
+
max_steps = dp.data.shape[0]
|
713 |
+
time_and_batch = (max_steps, len(traj_hints))
|
714 |
+
|
715 |
+
# Create zero-filled space for the batched hints, then copy each hint
|
716 |
+
# up to the corresponding length.
|
717 |
+
batched_traj = jax.tree_util.tree_map(
|
718 |
+
lambda x: np.zeros(time_and_batch + x.shape[2:]),
|
719 |
+
traj_hints[0])
|
720 |
+
hint_lengths = np.zeros(len(traj_hints))
|
721 |
+
|
722 |
+
for sample_idx, cur_sample in enumerate(traj_hints):
|
723 |
+
for i in range(len(cur_sample)):
|
724 |
+
assert batched_traj[i].name == cur_sample[i].name
|
725 |
+
cur_data = cur_sample[i].data
|
726 |
+
cur_length = cur_data.shape[0]
|
727 |
+
batched_traj[i].data[:cur_length, sample_idx:sample_idx+1] = cur_data
|
728 |
+
if i > 0:
|
729 |
+
assert hint_lengths[sample_idx] == cur_length
|
730 |
+
else:
|
731 |
+
hint_lengths[sample_idx] = cur_length
|
732 |
+
return batched_traj, hint_lengths
|
733 |
+
|
734 |
+
|
735 |
+
def _subsample_data(
|
736 |
+
trajectory: Trajectory,
|
737 |
+
idx: List[int],
|
738 |
+
axis: int = 0,
|
739 |
+
) -> Trajectory:
|
740 |
+
"""New `Trajectory` where each `DataPoint`'s data is subsampled along axis."""
|
741 |
+
sampled_traj = []
|
742 |
+
for dp in trajectory:
|
743 |
+
sampled_data = np.take(dp.data, idx, axis=axis)
|
744 |
+
sampled_traj.append(
|
745 |
+
probing.DataPoint(dp.name, dp.location, dp.type_, sampled_data))
|
746 |
+
return sampled_traj
|
747 |
+
|
748 |
+
|
749 |
+
def _preprocess_permutations(probes, enforce_permutations):
|
750 |
+
"""Replace should-be permutations with proper permutation pointer + mask."""
|
751 |
+
output = []
|
752 |
+
for x in probes:
|
753 |
+
if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
|
754 |
+
output.append(x)
|
755 |
+
continue
|
756 |
+
assert x.location == specs.Location.NODE
|
757 |
+
if enforce_permutations:
|
758 |
+
new_x, mask = probing.predecessor_to_cyclic_predecessor_and_first(x.data)
|
759 |
+
output.append(
|
760 |
+
probing.DataPoint(
|
761 |
+
name=x.name,
|
762 |
+
location=x.location,
|
763 |
+
type_=specs.Type.PERMUTATION_POINTER,
|
764 |
+
data=new_x))
|
765 |
+
output.append(
|
766 |
+
probing.DataPoint(
|
767 |
+
name=x.name + '_mask',
|
768 |
+
location=x.location,
|
769 |
+
type_=specs.Type.MASK_ONE,
|
770 |
+
data=mask))
|
771 |
+
else:
|
772 |
+
output.append(probing.DataPoint(name=x.name, location=x.location,
|
773 |
+
type_=specs.Type.POINTER, data=x.data))
|
774 |
+
return output
|
775 |
+
|
776 |
+
|
777 |
+
def process_permutations(spec, sample_iterator, enforce_permutations):
|
778 |
+
"""Replace should-be permutations with proper permutation pointer + mask."""
|
779 |
+
def _iterate():
|
780 |
+
while True:
|
781 |
+
feedback = next(sample_iterator)
|
782 |
+
features = feedback.features
|
783 |
+
inputs = _preprocess_permutations(features.inputs, enforce_permutations)
|
784 |
+
hints = _preprocess_permutations(features.hints, enforce_permutations)
|
785 |
+
outputs = _preprocess_permutations(feedback.outputs, enforce_permutations)
|
786 |
+
features = features._replace(inputs=tuple(inputs),
|
787 |
+
hints=tuple(hints))
|
788 |
+
feedback = feedback._replace(features=features,
|
789 |
+
outputs=outputs)
|
790 |
+
yield feedback
|
791 |
+
|
792 |
+
new_spec = {}
|
793 |
+
for k in spec:
|
794 |
+
if (spec[k][1] == specs.Location.NODE and
|
795 |
+
spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
|
796 |
+
if enforce_permutations:
|
797 |
+
new_spec[k] = (spec[k][0], spec[k][1], specs.Type.PERMUTATION_POINTER)
|
798 |
+
new_spec[k + '_mask'] = (spec[k][0], spec[k][1], specs.Type.MASK_ONE)
|
799 |
+
else:
|
800 |
+
new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
|
801 |
+
else:
|
802 |
+
new_spec[k] = spec[k]
|
803 |
+
|
804 |
+
return new_spec, _iterate()
|
805 |
+
|
806 |
+
|
807 |
+
def process_pred_as_input(spec, sample_iterator):
|
808 |
+
"""Move pred_h hint to pred input."""
|
809 |
+
def _iterate():
|
810 |
+
while True:
|
811 |
+
feedback = next(sample_iterator)
|
812 |
+
features = feedback.features
|
813 |
+
pred_h = [h for h in features.hints if h.name == 'pred_h']
|
814 |
+
if pred_h:
|
815 |
+
assert len(pred_h) == 1
|
816 |
+
pred_h = pred_h[0]
|
817 |
+
hints = [h for h in features.hints if h.name != 'pred_h']
|
818 |
+
for i in range(len(features.lengths)):
|
819 |
+
assert np.sum(np.abs(pred_h.data[1:int(features.lengths[i]), i] -
|
820 |
+
pred_h.data[0, i])) == 0.0
|
821 |
+
inputs = tuple(features.inputs) + (
|
822 |
+
probing.DataPoint(name='pred', location=pred_h.location,
|
823 |
+
type_=pred_h.type_, data=pred_h.data[0]),)
|
824 |
+
features = features._replace(inputs=tuple(inputs),
|
825 |
+
hints=tuple(hints))
|
826 |
+
feedback = feedback._replace(features=features)
|
827 |
+
yield feedback
|
828 |
+
|
829 |
+
new_spec = {}
|
830 |
+
for k in spec:
|
831 |
+
if k == 'pred_h':
|
832 |
+
assert spec[k] == (specs.Stage.HINT, specs.Location.NODE,
|
833 |
+
specs.Type.POINTER)
|
834 |
+
new_spec['pred'] = (specs.Stage.INPUT, specs.Location.NODE,
|
835 |
+
specs.Type.POINTER)
|
836 |
+
else:
|
837 |
+
new_spec[k] = spec[k]
|
838 |
+
|
839 |
+
return new_spec, _iterate()
|
840 |
+
|
841 |
+
|
842 |
+
def process_random_pos(sample_iterator, rng):
|
843 |
+
"""Randomize the `pos` input from a sampler.
|
844 |
+
|
845 |
+
The `pos` input is, by default, a scalar uniformly spaced between 0 and 1
|
846 |
+
across the nodes. The exception are string algorithms (naive_string_matcher,
|
847 |
+
kmp_string_matcher and lcs_length), where the `pos` sequence is split into
|
848 |
+
needle and haystack (or first and second string, for lcs_length). Here
|
849 |
+
we replace the uniformly spaced `pos` with an ordered sequence of random
|
850 |
+
scalars, or, for string algorithms, two ordered sequences of random scalars.
|
851 |
+
|
852 |
+
Args:
|
853 |
+
sample_iterator: An iterator producing samples with non-random `pos` inputs.
|
854 |
+
rng: Numpy random generator
|
855 |
+
Returns:
|
856 |
+
An iterator returning the samples with randomized `pos` inputs.
|
857 |
+
"""
|
858 |
+
def _iterate():
|
859 |
+
while True:
|
860 |
+
feedback = next(sample_iterator)
|
861 |
+
inputs = feedback.features.inputs
|
862 |
+
pos, = [x for x in inputs if x.name == 'pos']
|
863 |
+
batch_size, num_nodes = pos.data.shape
|
864 |
+
unsorted = rng.uniform(size=(batch_size, num_nodes))
|
865 |
+
new_pos = []
|
866 |
+
for i in range(batch_size): # we check one example at a time.
|
867 |
+
# We find if there are splits in the pos sequence, marked by zeros.
|
868 |
+
# We know there will always be at least 1 zero, if there's no split.
|
869 |
+
split, = np.where(pos.data[i] == 0)
|
870 |
+
split = np.concatenate([split, [num_nodes]])
|
871 |
+
# We construct the randomized pos by sorting the random values in each
|
872 |
+
# split and concatenating them.
|
873 |
+
new_pos.append(
|
874 |
+
np.concatenate([np.sort(unsorted[i, split[j]:split[j+1]])
|
875 |
+
for j in range(len(split) - 1)]))
|
876 |
+
pos.data = np.array(new_pos)
|
877 |
+
inputs = [(pos if x.name == 'pos' else x) for x in inputs]
|
878 |
+
features = feedback.features._replace(inputs=inputs)
|
879 |
+
feedback = feedback._replace(features=features)
|
880 |
+
yield feedback
|
881 |
+
|
882 |
+
return _iterate()
|
benchmarks/CLRS/env/samplers_test.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Unit tests for `samplers.py`."""
|
17 |
+
|
18 |
+
from absl.testing import absltest
|
19 |
+
from absl.testing import parameterized
|
20 |
+
|
21 |
+
import chex
|
22 |
+
from clrs._src import probing
|
23 |
+
from clrs._src import samplers
|
24 |
+
from clrs._src import specs
|
25 |
+
import jax
|
26 |
+
import numpy as np
|
27 |
+
|
28 |
+
|
29 |
+
class SamplersTest(parameterized.TestCase):
|
30 |
+
|
31 |
+
@parameterized.parameters(*specs.CLRS_30_ALGS)
|
32 |
+
def test_sampler_determinism(self, name):
|
33 |
+
num_samples = 3
|
34 |
+
num_nodes = 10
|
35 |
+
sampler, _ = samplers.build_sampler(name, num_samples, num_nodes)
|
36 |
+
|
37 |
+
np.random.seed(47) # Set seed
|
38 |
+
feedback = sampler.next()
|
39 |
+
expected = feedback.outputs[0].data.copy()
|
40 |
+
|
41 |
+
np.random.seed(48) # Set a different seed
|
42 |
+
feedback = sampler.next()
|
43 |
+
actual = feedback.outputs[0].data.copy()
|
44 |
+
|
45 |
+
# Validate that datasets are the same.
|
46 |
+
np.testing.assert_array_equal(expected, actual)
|
47 |
+
|
48 |
+
@parameterized.parameters(*specs.CLRS_30_ALGS)
|
49 |
+
def test_sampler_batch_determinism(self, name):
|
50 |
+
num_samples = 10
|
51 |
+
batch_size = 5
|
52 |
+
num_nodes = 10
|
53 |
+
seed = 0
|
54 |
+
sampler_1, _ = samplers.build_sampler(
|
55 |
+
name, num_samples, num_nodes, seed=seed)
|
56 |
+
sampler_2, _ = samplers.build_sampler(
|
57 |
+
name, num_samples, num_nodes, seed=seed)
|
58 |
+
|
59 |
+
feedback_1 = sampler_1.next(batch_size)
|
60 |
+
feedback_2 = sampler_2.next(batch_size)
|
61 |
+
|
62 |
+
# Validate that datasets are the same.
|
63 |
+
jax.tree_util.tree_map(np.testing.assert_array_equal, feedback_1,
|
64 |
+
feedback_2)
|
65 |
+
|
66 |
+
def test_end_to_end(self):
|
67 |
+
num_samples = 7
|
68 |
+
num_nodes = 3
|
69 |
+
sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes)
|
70 |
+
feedback = sampler.next()
|
71 |
+
|
72 |
+
inputs = feedback.features.inputs
|
73 |
+
self.assertLen(inputs, 4)
|
74 |
+
self.assertEqual(inputs[0].name, "pos")
|
75 |
+
self.assertEqual(inputs[0].data.shape, (num_samples, num_nodes))
|
76 |
+
|
77 |
+
outputs = feedback.outputs
|
78 |
+
self.assertLen(outputs, 1)
|
79 |
+
self.assertEqual(outputs[0].name, "pi")
|
80 |
+
self.assertEqual(outputs[0].data.shape, (num_samples, num_nodes))
|
81 |
+
|
82 |
+
def test_batch_size(self):
|
83 |
+
num_samples = 7
|
84 |
+
num_nodes = 3
|
85 |
+
sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes)
|
86 |
+
|
87 |
+
# Full-batch.
|
88 |
+
feedback = sampler.next()
|
89 |
+
for dp in feedback.features.inputs: # [B, ...]
|
90 |
+
self.assertEqual(dp.data.shape[0], num_samples)
|
91 |
+
|
92 |
+
for dp in feedback.outputs: # [B, ...]
|
93 |
+
self.assertEqual(dp.data.shape[0], num_samples)
|
94 |
+
|
95 |
+
for dp in feedback.features.hints: # [T, B, ...]
|
96 |
+
self.assertEqual(dp.data.shape[1], num_samples)
|
97 |
+
|
98 |
+
self.assertLen(feedback.features.lengths, num_samples)
|
99 |
+
|
100 |
+
# Specified batch.
|
101 |
+
batch_size = 5
|
102 |
+
feedback = sampler.next(batch_size)
|
103 |
+
|
104 |
+
for dp in feedback.features.inputs: # [B, ...]
|
105 |
+
self.assertEqual(dp.data.shape[0], batch_size)
|
106 |
+
|
107 |
+
for dp in feedback.outputs: # [B, ...]
|
108 |
+
self.assertEqual(dp.data.shape[0], batch_size)
|
109 |
+
|
110 |
+
for dp in feedback.features.hints: # [T, B, ...]
|
111 |
+
self.assertEqual(dp.data.shape[1], batch_size)
|
112 |
+
|
113 |
+
self.assertLen(feedback.features.lengths, batch_size)
|
114 |
+
|
115 |
+
def test_batch_io(self):
|
116 |
+
sample = [
|
117 |
+
probing.DataPoint(
|
118 |
+
name="x",
|
119 |
+
location=specs.Location.NODE,
|
120 |
+
type_=specs.Type.SCALAR,
|
121 |
+
data=np.zeros([1, 3]),
|
122 |
+
),
|
123 |
+
probing.DataPoint(
|
124 |
+
name="y",
|
125 |
+
location=specs.Location.EDGE,
|
126 |
+
type_=specs.Type.MASK,
|
127 |
+
data=np.zeros([1, 3, 3]),
|
128 |
+
),
|
129 |
+
]
|
130 |
+
|
131 |
+
trajectory = [sample.copy(), sample.copy(), sample.copy(), sample.copy()]
|
132 |
+
batched = samplers._batch_io(trajectory)
|
133 |
+
|
134 |
+
np.testing.assert_array_equal(batched[0].data, np.zeros([4, 3]))
|
135 |
+
np.testing.assert_array_equal(batched[1].data, np.zeros([4, 3, 3]))
|
136 |
+
|
137 |
+
def test_batch_hint(self):
|
138 |
+
sample0 = [
|
139 |
+
probing.DataPoint(
|
140 |
+
name="x",
|
141 |
+
location=specs.Location.NODE,
|
142 |
+
type_=specs.Type.MASK,
|
143 |
+
data=np.zeros([2, 1, 3]),
|
144 |
+
),
|
145 |
+
probing.DataPoint(
|
146 |
+
name="y",
|
147 |
+
location=specs.Location.NODE,
|
148 |
+
type_=specs.Type.POINTER,
|
149 |
+
data=np.zeros([2, 1, 3]),
|
150 |
+
),
|
151 |
+
]
|
152 |
+
|
153 |
+
sample1 = [
|
154 |
+
probing.DataPoint(
|
155 |
+
name="x",
|
156 |
+
location=specs.Location.NODE,
|
157 |
+
type_=specs.Type.MASK,
|
158 |
+
data=np.zeros([1, 1, 3]),
|
159 |
+
),
|
160 |
+
probing.DataPoint(
|
161 |
+
name="y",
|
162 |
+
location=specs.Location.NODE,
|
163 |
+
type_=specs.Type.POINTER,
|
164 |
+
data=np.zeros([1, 1, 3]),
|
165 |
+
),
|
166 |
+
]
|
167 |
+
|
168 |
+
trajectory = [sample0, sample1]
|
169 |
+
batched, lengths = samplers._batch_hints(trajectory, 0)
|
170 |
+
|
171 |
+
np.testing.assert_array_equal(batched[0].data, np.zeros([2, 2, 3]))
|
172 |
+
np.testing.assert_array_equal(batched[1].data, np.zeros([2, 2, 3]))
|
173 |
+
np.testing.assert_array_equal(lengths, np.array([2, 1]))
|
174 |
+
|
175 |
+
batched, lengths = samplers._batch_hints(trajectory, 5)
|
176 |
+
|
177 |
+
np.testing.assert_array_equal(batched[0].data, np.zeros([5, 2, 3]))
|
178 |
+
np.testing.assert_array_equal(batched[1].data, np.zeros([5, 2, 3]))
|
179 |
+
np.testing.assert_array_equal(lengths, np.array([2, 1]))
|
180 |
+
|
181 |
+
def test_padding(self):
|
182 |
+
lens = np.random.choice(10, (10,), replace=True) + 1
|
183 |
+
trajectory = []
|
184 |
+
for len_ in lens:
|
185 |
+
trajectory.append([
|
186 |
+
probing.DataPoint(
|
187 |
+
name="x",
|
188 |
+
location=specs.Location.NODE,
|
189 |
+
type_=specs.Type.MASK,
|
190 |
+
data=np.ones([len_, 1, 3]),
|
191 |
+
)
|
192 |
+
])
|
193 |
+
|
194 |
+
batched, lengths = samplers._batch_hints(trajectory, 0)
|
195 |
+
np.testing.assert_array_equal(lengths, lens)
|
196 |
+
|
197 |
+
for i in range(len(lens)):
|
198 |
+
ones = batched[0].data[:lens[i], i, :]
|
199 |
+
zeros = batched[0].data[lens[i]:, i, :]
|
200 |
+
np.testing.assert_array_equal(ones, np.ones_like(ones))
|
201 |
+
np.testing.assert_array_equal(zeros, np.zeros_like(zeros))
|
202 |
+
|
203 |
+
|
204 |
+
class ProcessRandomPosTest(parameterized.TestCase):
|
205 |
+
|
206 |
+
@parameterized.parameters(["insertion_sort", "naive_string_matcher"])
|
207 |
+
def test_random_pos(self, algorithm_name):
|
208 |
+
batch_size, length = 12, 10
|
209 |
+
def _make_sampler():
|
210 |
+
sampler, _ = samplers.build_sampler(
|
211 |
+
algorithm_name,
|
212 |
+
seed=0,
|
213 |
+
num_samples=100,
|
214 |
+
length=length,
|
215 |
+
)
|
216 |
+
while True:
|
217 |
+
yield sampler.next(batch_size)
|
218 |
+
sampler_1 = _make_sampler()
|
219 |
+
sampler_2 = _make_sampler()
|
220 |
+
sampler_2 = samplers.process_random_pos(sampler_2, np.random.RandomState(0))
|
221 |
+
|
222 |
+
batch_without_rand_pos = next(sampler_1)
|
223 |
+
batch_with_rand_pos = next(sampler_2)
|
224 |
+
pos_idx = [x.name for x in batch_without_rand_pos.features.inputs].index(
|
225 |
+
"pos")
|
226 |
+
fixed_pos = batch_without_rand_pos.features.inputs[pos_idx]
|
227 |
+
rand_pos = batch_with_rand_pos.features.inputs[pos_idx]
|
228 |
+
self.assertEqual(rand_pos.location, specs.Location.NODE)
|
229 |
+
self.assertEqual(rand_pos.type_, specs.Type.SCALAR)
|
230 |
+
self.assertEqual(rand_pos.data.shape, (batch_size, length))
|
231 |
+
self.assertEqual(rand_pos.data.shape, fixed_pos.data.shape)
|
232 |
+
self.assertEqual(rand_pos.type_, fixed_pos.type_)
|
233 |
+
self.assertEqual(rand_pos.location, fixed_pos.location)
|
234 |
+
|
235 |
+
assert (rand_pos.data.std(axis=0) > 1e-3).all()
|
236 |
+
assert (fixed_pos.data.std(axis=0) < 1e-9).all()
|
237 |
+
if "string" in algorithm_name:
|
238 |
+
expected = np.concatenate([np.arange(4*length//5)/(4*length//5),
|
239 |
+
np.arange(length//5)/(length//5)])
|
240 |
+
else:
|
241 |
+
expected = np.arange(length)/length
|
242 |
+
np.testing.assert_array_equal(
|
243 |
+
fixed_pos.data, np.broadcast_to(expected, (batch_size, length)))
|
244 |
+
|
245 |
+
batch_with_rand_pos.features.inputs[pos_idx] = fixed_pos
|
246 |
+
chex.assert_trees_all_equal(batch_with_rand_pos, batch_without_rand_pos)
|
247 |
+
|
248 |
+
|
249 |
+
if __name__ == "__main__":
|
250 |
+
absltest.main()
|
benchmarks/CLRS/env/specs.py
ADDED
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Algorithm specs.
|
17 |
+
|
18 |
+
The "spec" of each algorithm is a static set of `(stage, loc, type)`-tuples.
|
19 |
+
|
20 |
+
- `stage`: One of either an `input`, `output` or `hint`
|
21 |
+
- `location`: Each datum is associated with either the `node`, `edge` or `graph`
|
22 |
+
- `type`: Either a `scalar`, `categorical`, `mask`, `mask_one` or `pointer`
|
23 |
+
|
24 |
+
The dataflow for an algorithm is represented by `(stage, loc, type, data)`
|
25 |
+
"probes" that are valid under that algorithm's spec. It contains a single
|
26 |
+
snapshot for each `input` and `output` and a time-series of intermediate
|
27 |
+
algorithmic states (`hint`).
|
28 |
+
|
29 |
+
At minimum, each node contains a `pos` probe that serves as a unique index e.g.
|
30 |
+
for representing sequential data where appropriate
|
31 |
+
"""
|
32 |
+
|
33 |
+
import types
|
34 |
+
from typing import Dict, Tuple
|
35 |
+
|
36 |
+
|
37 |
+
class Stage:
|
38 |
+
INPUT = 'input'
|
39 |
+
OUTPUT = 'output'
|
40 |
+
HINT = 'hint'
|
41 |
+
|
42 |
+
|
43 |
+
class Location:
|
44 |
+
NODE = 'node'
|
45 |
+
EDGE = 'edge'
|
46 |
+
GRAPH = 'graph'
|
47 |
+
|
48 |
+
|
49 |
+
class Type:
|
50 |
+
SCALAR = 'scalar'
|
51 |
+
CATEGORICAL = 'categorical'
|
52 |
+
MASK = 'mask'
|
53 |
+
MASK_ONE = 'mask_one'
|
54 |
+
POINTER = 'pointer'
|
55 |
+
SHOULD_BE_PERMUTATION = 'should_be_permutation'
|
56 |
+
PERMUTATION_POINTER = 'permutation_pointer'
|
57 |
+
SOFT_POINTER = 'soft_pointer'
|
58 |
+
|
59 |
+
|
60 |
+
class OutputClass:
|
61 |
+
POSITIVE = 1
|
62 |
+
NEGATIVE = 0
|
63 |
+
MASKED = -1
|
64 |
+
|
65 |
+
Spec = Dict[str, Tuple[str, str, str]]
|
66 |
+
|
67 |
+
CLRS_30_ALGS = [
|
68 |
+
'articulation_points',
|
69 |
+
'activity_selector',
|
70 |
+
'bellman_ford',
|
71 |
+
'bfs',
|
72 |
+
'binary_search',
|
73 |
+
'bridges',
|
74 |
+
'bubble_sort',
|
75 |
+
'dag_shortest_paths',
|
76 |
+
'dfs',
|
77 |
+
'dijkstra',
|
78 |
+
'find_maximum_subarray_kadane',
|
79 |
+
'floyd_warshall',
|
80 |
+
'graham_scan',
|
81 |
+
'heapsort',
|
82 |
+
'insertion_sort',
|
83 |
+
'jarvis_march',
|
84 |
+
'kmp_matcher',
|
85 |
+
'lcs_length',
|
86 |
+
'matrix_chain_order',
|
87 |
+
'minimum',
|
88 |
+
'mst_kruskal',
|
89 |
+
'mst_prim',
|
90 |
+
'naive_string_matcher',
|
91 |
+
'optimal_bst',
|
92 |
+
'quickselect',
|
93 |
+
'quicksort',
|
94 |
+
'segments_intersect',
|
95 |
+
'strongly_connected_components',
|
96 |
+
'task_scheduling',
|
97 |
+
'topological_sort',
|
98 |
+
]
|
99 |
+
|
100 |
+
|
101 |
+
ALGO_IDX_INPUT_NAME = 'algo_idx'
|
102 |
+
|
103 |
+
# Algorithms have varying numbers of signals they are evaluated on.
|
104 |
+
# To compensate for that, we issue more samples for those who use a small
|
105 |
+
# number of signals.
|
106 |
+
CLRS_30_ALGS_SETTINGS = {alg: {'num_samples_multiplier': 1}
|
107 |
+
for alg in CLRS_30_ALGS}
|
108 |
+
CLRS_30_ALGS_SETTINGS['find_maximum_subarray_kadane'][
|
109 |
+
'num_samples_multiplier'] = 32
|
110 |
+
for alg in ['quickselect', 'minimum', 'binary_search', 'naive_string_matcher',
|
111 |
+
'kmp_matcher', 'segments_intersect']:
|
112 |
+
CLRS_30_ALGS_SETTINGS[alg]['num_samples_multiplier'] = 64
|
113 |
+
|
114 |
+
|
115 |
+
SPECS = types.MappingProxyType({
|
116 |
+
'insertion_sort': {
|
117 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
118 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
119 |
+
'pred': (Stage.OUTPUT, Location.NODE, Type.SHOULD_BE_PERMUTATION),
|
120 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
121 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
122 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
123 |
+
},
|
124 |
+
'bubble_sort': {
|
125 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
126 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
127 |
+
'pred': (Stage.OUTPUT, Location.NODE, Type.SHOULD_BE_PERMUTATION),
|
128 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
129 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
130 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
131 |
+
},
|
132 |
+
'heapsort': {
|
133 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
134 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
135 |
+
'pred': (Stage.OUTPUT, Location.NODE, Type.SHOULD_BE_PERMUTATION),
|
136 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
137 |
+
'parent': (Stage.HINT, Location.NODE, Type.POINTER),
|
138 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
139 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
140 |
+
'largest': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
141 |
+
'heap_size': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
142 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
|
143 |
+
},
|
144 |
+
'quicksort': {
|
145 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
146 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
147 |
+
'pred': (Stage.OUTPUT, Location.NODE, Type.SHOULD_BE_PERMUTATION),
|
148 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
149 |
+
'p': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
150 |
+
'r': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
151 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
152 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
153 |
+
},
|
154 |
+
'quickselect': {
|
155 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
156 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
157 |
+
'median': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
158 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
159 |
+
'p': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
160 |
+
'r': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
161 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
162 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
163 |
+
'i_rank': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
164 |
+
'target': (Stage.HINT, Location.GRAPH, Type.SCALAR)
|
165 |
+
},
|
166 |
+
'minimum': {
|
167 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
168 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
169 |
+
'min': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
170 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
171 |
+
'min_h': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
172 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
173 |
+
},
|
174 |
+
'binary_search': {
|
175 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
176 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
177 |
+
'target': (Stage.INPUT, Location.GRAPH, Type.SCALAR),
|
178 |
+
'return': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
179 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
180 |
+
'low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
181 |
+
'high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
182 |
+
'mid': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
183 |
+
},
|
184 |
+
'find_maximum_subarray': {
|
185 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
186 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
187 |
+
'start': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
188 |
+
'end': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
189 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
190 |
+
'low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
191 |
+
'high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
192 |
+
'mid': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
193 |
+
'left_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
194 |
+
'left_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
195 |
+
'left_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
196 |
+
'right_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
197 |
+
'right_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
198 |
+
'right_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
199 |
+
'cross_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
200 |
+
'cross_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
201 |
+
'cross_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
202 |
+
'ret_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
203 |
+
'ret_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
204 |
+
'ret_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
205 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
206 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
207 |
+
'sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
208 |
+
'left_x_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
209 |
+
'right_x_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
210 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
|
211 |
+
},
|
212 |
+
'find_maximum_subarray_kadane': {
|
213 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
214 |
+
'key': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
215 |
+
'start': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
216 |
+
'end': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
217 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
218 |
+
'best_low': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
219 |
+
'best_high': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
220 |
+
'best_sum': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
221 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
222 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
223 |
+
'sum': (Stage.HINT, Location.GRAPH, Type.SCALAR)
|
224 |
+
},
|
225 |
+
'matrix_chain_order': {
|
226 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
227 |
+
'p': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
228 |
+
's': (Stage.OUTPUT, Location.EDGE, Type.POINTER),
|
229 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
230 |
+
'm': (Stage.HINT, Location.EDGE, Type.SCALAR),
|
231 |
+
's_h': (Stage.HINT, Location.EDGE, Type.POINTER),
|
232 |
+
'msk': (Stage.HINT, Location.EDGE, Type.MASK)
|
233 |
+
},
|
234 |
+
'lcs_length': {
|
235 |
+
'string': (Stage.INPUT, Location.NODE, Type.MASK),
|
236 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
237 |
+
'key': (Stage.INPUT, Location.NODE, Type.CATEGORICAL),
|
238 |
+
'b': (Stage.OUTPUT, Location.EDGE, Type.CATEGORICAL),
|
239 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
240 |
+
'b_h': (Stage.HINT, Location.EDGE, Type.CATEGORICAL),
|
241 |
+
'c': (Stage.HINT, Location.EDGE, Type.SCALAR)
|
242 |
+
},
|
243 |
+
'optimal_bst': {
|
244 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
245 |
+
'p': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
246 |
+
'q': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
247 |
+
'root': (Stage.OUTPUT, Location.EDGE, Type.POINTER),
|
248 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
249 |
+
'root_h': (Stage.HINT, Location.EDGE, Type.POINTER),
|
250 |
+
'e': (Stage.HINT, Location.EDGE, Type.SCALAR),
|
251 |
+
'w': (Stage.HINT, Location.EDGE, Type.SCALAR),
|
252 |
+
'msk': (Stage.HINT, Location.EDGE, Type.MASK)
|
253 |
+
},
|
254 |
+
'activity_selector': {
|
255 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
256 |
+
's': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
257 |
+
'f': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
258 |
+
'selected': (Stage.OUTPUT, Location.NODE, Type.MASK),
|
259 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
260 |
+
'selected_h': (Stage.HINT, Location.NODE, Type.MASK),
|
261 |
+
'm': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
262 |
+
'k': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
263 |
+
},
|
264 |
+
'task_scheduling': {
|
265 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
266 |
+
'd': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
267 |
+
'w': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
268 |
+
'selected': (Stage.OUTPUT, Location.NODE, Type.MASK),
|
269 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
270 |
+
'selected_h': (Stage.HINT, Location.NODE, Type.MASK),
|
271 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
272 |
+
't': (Stage.HINT, Location.GRAPH, Type.SCALAR)
|
273 |
+
},
|
274 |
+
'dfs': {
|
275 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
276 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
277 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
278 |
+
'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
|
279 |
+
'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
280 |
+
'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
|
281 |
+
'd': (Stage.HINT, Location.NODE, Type.SCALAR),
|
282 |
+
'f': (Stage.HINT, Location.NODE, Type.SCALAR),
|
283 |
+
's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
|
284 |
+
's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
285 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
286 |
+
'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
287 |
+
's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
288 |
+
'time': (Stage.HINT, Location.GRAPH, Type.SCALAR)
|
289 |
+
},
|
290 |
+
'topological_sort': {
|
291 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
292 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
293 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
294 |
+
'topo': (Stage.OUTPUT, Location.NODE, Type.POINTER),
|
295 |
+
'topo_head': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
296 |
+
'topo_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
297 |
+
'topo_head_h': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
298 |
+
'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
|
299 |
+
's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
|
300 |
+
's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
301 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
302 |
+
'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
303 |
+
's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
304 |
+
},
|
305 |
+
'strongly_connected_components': {
|
306 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
307 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
308 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
309 |
+
'scc_id': (Stage.OUTPUT, Location.NODE, Type.POINTER),
|
310 |
+
'scc_id_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
311 |
+
'A_t': (Stage.HINT, Location.EDGE, Type.MASK),
|
312 |
+
'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
|
313 |
+
'd': (Stage.HINT, Location.NODE, Type.SCALAR),
|
314 |
+
'f': (Stage.HINT, Location.NODE, Type.SCALAR),
|
315 |
+
's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
|
316 |
+
's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
317 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
318 |
+
'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
319 |
+
's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
320 |
+
'time': (Stage.HINT, Location.GRAPH, Type.SCALAR),
|
321 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.MASK)
|
322 |
+
},
|
323 |
+
'articulation_points': {
|
324 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
325 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
326 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
327 |
+
'is_cut': (Stage.OUTPUT, Location.NODE, Type.MASK),
|
328 |
+
'is_cut_h': (Stage.HINT, Location.NODE, Type.MASK),
|
329 |
+
'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
330 |
+
'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
|
331 |
+
'd': (Stage.HINT, Location.NODE, Type.SCALAR),
|
332 |
+
'f': (Stage.HINT, Location.NODE, Type.SCALAR),
|
333 |
+
'low': (Stage.HINT, Location.NODE, Type.SCALAR),
|
334 |
+
'child_cnt': (Stage.HINT, Location.NODE, Type.SCALAR),
|
335 |
+
's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
|
336 |
+
's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
337 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
338 |
+
'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
339 |
+
's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
340 |
+
'time': (Stage.HINT, Location.GRAPH, Type.SCALAR)
|
341 |
+
},
|
342 |
+
'bridges': {
|
343 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
344 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
345 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
346 |
+
'is_bridge': (Stage.OUTPUT, Location.EDGE, Type.MASK),
|
347 |
+
'is_bridge_h': (Stage.HINT, Location.EDGE, Type.MASK),
|
348 |
+
'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
349 |
+
'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
|
350 |
+
'd': (Stage.HINT, Location.NODE, Type.SCALAR),
|
351 |
+
'f': (Stage.HINT, Location.NODE, Type.SCALAR),
|
352 |
+
'low': (Stage.HINT, Location.NODE, Type.SCALAR),
|
353 |
+
's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
|
354 |
+
's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
355 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
356 |
+
'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
357 |
+
's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
358 |
+
'time': (Stage.HINT, Location.GRAPH, Type.SCALAR)
|
359 |
+
},
|
360 |
+
'bfs': {
|
361 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
362 |
+
's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
|
363 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
364 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
365 |
+
'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
|
366 |
+
'reach_h': (Stage.HINT, Location.NODE, Type.MASK),
|
367 |
+
'pi_h': (Stage.HINT, Location.NODE, Type.POINTER)
|
368 |
+
},
|
369 |
+
'mst_kruskal': {
|
370 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
371 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
372 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
373 |
+
'in_mst': (Stage.OUTPUT, Location.EDGE, Type.MASK),
|
374 |
+
'in_mst_h': (Stage.HINT, Location.EDGE, Type.MASK),
|
375 |
+
'pi': (Stage.HINT, Location.NODE, Type.POINTER),
|
376 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
377 |
+
'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
378 |
+
'root_u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
379 |
+
'root_v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
380 |
+
'mask_u': (Stage.HINT, Location.NODE, Type.MASK),
|
381 |
+
'mask_v': (Stage.HINT, Location.NODE, Type.MASK),
|
382 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
|
383 |
+
},
|
384 |
+
'mst_prim': {
|
385 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
386 |
+
's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
|
387 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
388 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
389 |
+
'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
|
390 |
+
'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
391 |
+
'key': (Stage.HINT, Location.NODE, Type.SCALAR),
|
392 |
+
'mark': (Stage.HINT, Location.NODE, Type.MASK),
|
393 |
+
'in_queue': (Stage.HINT, Location.NODE, Type.MASK),
|
394 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
395 |
+
},
|
396 |
+
'bellman_ford': {
|
397 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
398 |
+
's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
|
399 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
400 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
401 |
+
'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
|
402 |
+
'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
403 |
+
'd': (Stage.HINT, Location.NODE, Type.SCALAR),
|
404 |
+
'msk': (Stage.HINT, Location.NODE, Type.MASK)
|
405 |
+
},
|
406 |
+
'dag_shortest_paths': {
|
407 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
408 |
+
's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
|
409 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
410 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
411 |
+
'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
|
412 |
+
'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
413 |
+
'd': (Stage.HINT, Location.NODE, Type.SCALAR),
|
414 |
+
'mark': (Stage.HINT, Location.NODE, Type.MASK),
|
415 |
+
'topo_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
416 |
+
'topo_head_h': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
417 |
+
'color': (Stage.HINT, Location.NODE, Type.CATEGORICAL),
|
418 |
+
's_prev': (Stage.HINT, Location.NODE, Type.POINTER),
|
419 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
420 |
+
'v': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
421 |
+
's_last': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
422 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.MASK)
|
423 |
+
},
|
424 |
+
'dijkstra': {
|
425 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
426 |
+
's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
|
427 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
428 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
429 |
+
'pi': (Stage.OUTPUT, Location.NODE, Type.POINTER),
|
430 |
+
'pi_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
431 |
+
'd': (Stage.HINT, Location.NODE, Type.SCALAR),
|
432 |
+
'mark': (Stage.HINT, Location.NODE, Type.MASK),
|
433 |
+
'in_queue': (Stage.HINT, Location.NODE, Type.MASK),
|
434 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
435 |
+
},
|
436 |
+
'floyd_warshall': {
|
437 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
438 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
439 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
440 |
+
'Pi': (Stage.OUTPUT, Location.EDGE, Type.POINTER),
|
441 |
+
'Pi_h': (Stage.HINT, Location.EDGE, Type.POINTER),
|
442 |
+
'D': (Stage.HINT, Location.EDGE, Type.SCALAR),
|
443 |
+
'msk': (Stage.HINT, Location.EDGE, Type.MASK),
|
444 |
+
'k': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
445 |
+
},
|
446 |
+
'bipartite_matching': {
|
447 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
448 |
+
'A': (Stage.INPUT, Location.EDGE, Type.SCALAR),
|
449 |
+
'adj': (Stage.INPUT, Location.EDGE, Type.MASK),
|
450 |
+
's': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
|
451 |
+
't': (Stage.INPUT, Location.NODE, Type.MASK_ONE),
|
452 |
+
'in_matching': (Stage.OUTPUT, Location.EDGE, Type.MASK),
|
453 |
+
'in_matching_h': (Stage.HINT, Location.EDGE, Type.MASK),
|
454 |
+
'A_h': (Stage.HINT, Location.EDGE, Type.SCALAR),
|
455 |
+
'adj_h': (Stage.HINT, Location.EDGE, Type.MASK),
|
456 |
+
'd': (Stage.HINT, Location.NODE, Type.SCALAR),
|
457 |
+
'msk': (Stage.HINT, Location.NODE, Type.MASK),
|
458 |
+
'pi': (Stage.HINT, Location.NODE, Type.POINTER),
|
459 |
+
'u': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
460 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.MASK)
|
461 |
+
},
|
462 |
+
'naive_string_matcher': {
|
463 |
+
'string': (Stage.INPUT, Location.NODE, Type.MASK),
|
464 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
465 |
+
'key': (Stage.INPUT, Location.NODE, Type.CATEGORICAL),
|
466 |
+
'match': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
467 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
468 |
+
's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
469 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
470 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE)
|
471 |
+
},
|
472 |
+
'kmp_matcher': {
|
473 |
+
'string': (Stage.INPUT, Location.NODE, Type.MASK),
|
474 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
475 |
+
'key': (Stage.INPUT, Location.NODE, Type.CATEGORICAL),
|
476 |
+
'match': (Stage.OUTPUT, Location.NODE, Type.MASK_ONE),
|
477 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
478 |
+
'pi': (Stage.HINT, Location.NODE, Type.POINTER),
|
479 |
+
'is_reset': (Stage.HINT, Location.NODE, Type.MASK),
|
480 |
+
'k': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
481 |
+
'k_reset': (Stage.HINT, Location.GRAPH, Type.MASK),
|
482 |
+
'q': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
483 |
+
'q_reset': (Stage.HINT, Location.GRAPH, Type.MASK),
|
484 |
+
's': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
485 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
486 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.MASK)
|
487 |
+
},
|
488 |
+
'segments_intersect': {
|
489 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
490 |
+
'x': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
491 |
+
'y': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
492 |
+
'intersect': (Stage.OUTPUT, Location.GRAPH, Type.MASK),
|
493 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
494 |
+
'j': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
495 |
+
'k': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
496 |
+
'dir': (Stage.HINT, Location.NODE, Type.SCALAR),
|
497 |
+
'on_seg': (Stage.HINT, Location.NODE, Type.MASK)
|
498 |
+
},
|
499 |
+
'graham_scan': {
|
500 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
501 |
+
'x': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
502 |
+
'y': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
503 |
+
'in_hull': (Stage.OUTPUT, Location.NODE, Type.MASK),
|
504 |
+
'best': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
505 |
+
'atans': (Stage.HINT, Location.NODE, Type.SCALAR),
|
506 |
+
'in_hull_h': (Stage.HINT, Location.NODE, Type.MASK),
|
507 |
+
'stack_prev': (Stage.HINT, Location.NODE, Type.POINTER),
|
508 |
+
'last_stack': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
509 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
510 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
|
511 |
+
},
|
512 |
+
'jarvis_march': {
|
513 |
+
'pos': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
514 |
+
'x': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
515 |
+
'y': (Stage.INPUT, Location.NODE, Type.SCALAR),
|
516 |
+
'in_hull': (Stage.OUTPUT, Location.NODE, Type.MASK),
|
517 |
+
'pred_h': (Stage.HINT, Location.NODE, Type.POINTER),
|
518 |
+
'in_hull_h': (Stage.HINT, Location.NODE, Type.MASK),
|
519 |
+
'best': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
520 |
+
'last_point': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
521 |
+
'endpoint': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
522 |
+
'i': (Stage.HINT, Location.NODE, Type.MASK_ONE),
|
523 |
+
'phase': (Stage.HINT, Location.GRAPH, Type.CATEGORICAL)
|
524 |
+
}
|
525 |
+
})
|
benchmarks/CLRS/env/train.py
ADDED
@@ -0,0 +1,560 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Run training of one or more algorithmic tasks from CLRS."""
|
17 |
+
import os
|
18 |
+
# disable logging until training starts
|
19 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
20 |
+
|
21 |
+
import functools
|
22 |
+
import os
|
23 |
+
import shutil
|
24 |
+
from typing import Any, Dict, List
|
25 |
+
|
26 |
+
from absl import app
|
27 |
+
from absl import flags
|
28 |
+
from absl import logging
|
29 |
+
# disable logging until training starts
|
30 |
+
logging.set_verbosity(logging.ERROR)
|
31 |
+
|
32 |
+
import clrs
|
33 |
+
import jax
|
34 |
+
import numpy as np
|
35 |
+
import requests
|
36 |
+
import tensorflow as tf
|
37 |
+
from baselines import BaselineModel, BaselineModelChunked
|
38 |
+
import pickle
|
39 |
+
import copy
|
40 |
+
|
41 |
+
flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
|
42 |
+
flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
|
43 |
+
'Which training sizes to use. A size of -1 means '
|
44 |
+
'use the benchmark dataset.')
|
45 |
+
flags.DEFINE_integer('length_needle', -8,
|
46 |
+
'Length of needle for training and validation '
|
47 |
+
'(not testing) in string matching algorithms. '
|
48 |
+
'A negative value randomizes the length for each sample '
|
49 |
+
'between 1 and the opposite of the value. '
|
50 |
+
'A value of 0 means use always 1/4 of the length of '
|
51 |
+
'the haystack (the default sampler behavior).')
|
52 |
+
flags.DEFINE_integer('seed', 42, 'Random seed to set')
|
53 |
+
|
54 |
+
flags.DEFINE_boolean('random_pos', True,
|
55 |
+
'Randomize the pos input common to all algos.')
|
56 |
+
flags.DEFINE_boolean('enforce_permutations', True,
|
57 |
+
'Whether to enforce permutation-type node pointers.')
|
58 |
+
flags.DEFINE_boolean('enforce_pred_as_input', True,
|
59 |
+
'Whether to change pred_h hints into pred inputs.')
|
60 |
+
flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
|
61 |
+
flags.DEFINE_boolean('chunked_training', False,
|
62 |
+
'Whether to use chunking for training.')
|
63 |
+
flags.DEFINE_integer('chunk_length', 16,
|
64 |
+
'Time chunk length used for training (if '
|
65 |
+
'`chunked_training` is True.')
|
66 |
+
flags.DEFINE_integer('train_steps', 500, 'Number of training iterations.')
|
67 |
+
flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
|
68 |
+
flags.DEFINE_integer('test_every', 500, 'Evaluation frequency (in steps).')
|
69 |
+
flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')
|
70 |
+
|
71 |
+
flags.DEFINE_integer('hidden_size', 128,
|
72 |
+
'Number of hidden units of the model.')
|
73 |
+
flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors')
|
74 |
+
flags.DEFINE_integer('nb_msg_passing_steps', 1,
|
75 |
+
'Number of message passing steps to run per hint.')
|
76 |
+
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
|
77 |
+
flags.DEFINE_float('grad_clip_max_norm', 1.0,
|
78 |
+
'Gradient clipping by norm. 0.0 disables grad clipping')
|
79 |
+
flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
|
80 |
+
flags.DEFINE_float('hint_teacher_forcing', 0.0,
|
81 |
+
'Probability that ground-truth teacher hints are encoded '
|
82 |
+
'during training instead of predicted hints. Only '
|
83 |
+
'pertinent in encoded_decoded modes.')
|
84 |
+
flags.DEFINE_enum('hint_mode', 'encoded_decoded',
|
85 |
+
['encoded_decoded', 'decoded_only', 'none'],
|
86 |
+
'How should hints be used? Note, each mode defines a '
|
87 |
+
'separate task, with various difficulties. `encoded_decoded` '
|
88 |
+
'requires the model to explicitly materialise hint sequences '
|
89 |
+
'and therefore is hardest, but also most aligned to the '
|
90 |
+
'underlying algorithmic rule. Hence, `encoded_decoded` '
|
91 |
+
'should be treated as the default mode for our benchmark. '
|
92 |
+
'In `decoded_only`, hints are only used for defining '
|
93 |
+
'reconstruction losses. Often, this will perform well, but '
|
94 |
+
'note that we currently do not make any efforts to '
|
95 |
+
'counterbalance the various hint losses. Hence, for certain '
|
96 |
+
'tasks, the best performance will now be achievable with no '
|
97 |
+
'hint usage at all (`none`).')
|
98 |
+
flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
|
99 |
+
'How to process predicted hints when fed back as inputs.'
|
100 |
+
'In soft mode, we use softmaxes for categoricals, pointers '
|
101 |
+
'and mask_one, and sigmoids for masks. '
|
102 |
+
'In hard mode, we use argmax instead of softmax, and hard '
|
103 |
+
'thresholding of masks. '
|
104 |
+
'In hard_on_eval mode, soft mode is '
|
105 |
+
'used for training and hard mode is used for evaluation.')
|
106 |
+
flags.DEFINE_boolean('use_ln', True,
|
107 |
+
'Whether to use layer normalisation in the processor.')
|
108 |
+
flags.DEFINE_boolean('use_lstm', False,
|
109 |
+
'Whether to insert an LSTM after message passing.')
|
110 |
+
flags.DEFINE_integer('nb_triplet_fts', 8,
|
111 |
+
'How many triplet features to compute?')
|
112 |
+
|
113 |
+
flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
|
114 |
+
['default', 'xavier_on_scalars'],
|
115 |
+
'Initialiser to use for the encoders.')
|
116 |
+
flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
|
117 |
+
['deepsets', 'mpnn', 'pgn', 'pgn_mask',
|
118 |
+
'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
|
119 |
+
'gat', 'gatv2', 'gat_full', 'gatv2_full',
|
120 |
+
'gpgn', 'gpgn_mask', 'gmpnn',
|
121 |
+
'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
|
122 |
+
'Processor type to use as the network P.')
|
123 |
+
|
124 |
+
flags.DEFINE_string('checkpoint_path', './checkpoints',
|
125 |
+
'Path in which checkpoints are saved.')
|
126 |
+
flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
|
127 |
+
'Path in which dataset is stored.')
|
128 |
+
flags.DEFINE_boolean('freeze_processor', False,
|
129 |
+
'Whether to freeze the processor of the model.')
|
130 |
+
|
131 |
+
FLAGS = flags.FLAGS
|
132 |
+
|
133 |
+
|
134 |
+
PRED_AS_INPUT_ALGOS = [
|
135 |
+
'binary_search',
|
136 |
+
'minimum',
|
137 |
+
'find_maximum_subarray',
|
138 |
+
'find_maximum_subarray_kadane',
|
139 |
+
'matrix_chain_order',
|
140 |
+
'lcs_length',
|
141 |
+
'optimal_bst',
|
142 |
+
'activity_selector',
|
143 |
+
'task_scheduling',
|
144 |
+
'naive_string_matcher',
|
145 |
+
'kmp_matcher',
|
146 |
+
'jarvis_march']
|
147 |
+
|
148 |
+
|
149 |
+
def unpack(v):
|
150 |
+
try:
|
151 |
+
return v.item() # DeviceArray
|
152 |
+
except (AttributeError, ValueError):
|
153 |
+
return v
|
154 |
+
|
155 |
+
|
156 |
+
def _iterate_sampler(sampler, batch_size):
|
157 |
+
while True:
|
158 |
+
yield sampler.next(batch_size)
|
159 |
+
|
160 |
+
|
161 |
+
def _maybe_download_dataset(dataset_path):
|
162 |
+
"""Download CLRS30 dataset if needed."""
|
163 |
+
dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
|
164 |
+
if os.path.isdir(dataset_folder):
|
165 |
+
logging.info('Dataset found at %s. Skipping download.', dataset_folder)
|
166 |
+
return dataset_folder
|
167 |
+
logging.info('Dataset not found in %s. Downloading...', dataset_folder)
|
168 |
+
|
169 |
+
clrs_url = clrs.get_dataset_gcp_url()
|
170 |
+
request = requests.get(clrs_url, allow_redirects=True)
|
171 |
+
clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
|
172 |
+
os.makedirs(dataset_folder)
|
173 |
+
open(clrs_file, 'wb').write(request.content)
|
174 |
+
shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
|
175 |
+
os.remove(clrs_file)
|
176 |
+
return dataset_folder
|
177 |
+
|
178 |
+
|
179 |
+
def make_sampler(length: int,
|
180 |
+
rng: Any,
|
181 |
+
algorithm: str,
|
182 |
+
split: str,
|
183 |
+
batch_size: int,
|
184 |
+
multiplier: int,
|
185 |
+
randomize_pos: bool,
|
186 |
+
enforce_pred_as_input: bool,
|
187 |
+
enforce_permutations: bool,
|
188 |
+
chunked: bool,
|
189 |
+
chunk_length: int,
|
190 |
+
sampler_kwargs: Dict[str, Any]):
|
191 |
+
"""Create a sampler with given options.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
length: Size of samples (i.e., number of nodes in the graph).
|
195 |
+
A length of -1 will mean that the benchmark
|
196 |
+
dataset (for the given split) is used. Positive sizes will instantiate
|
197 |
+
samplers of the corresponding size.
|
198 |
+
rng: Numpy random state.
|
199 |
+
algorithm: The name of the algorithm to sample from.
|
200 |
+
split: 'train', 'val' or 'test'.
|
201 |
+
batch_size: Samples per batch.
|
202 |
+
multiplier: Integer multiplier for the number of samples in the dataset,
|
203 |
+
only used for positive sizes. Negative multiplier means infinite samples.
|
204 |
+
randomize_pos: Whether to randomize the `pos` input.
|
205 |
+
enforce_pred_as_input: Whether to convert fixed pred_h hints to inputs.
|
206 |
+
enforce_permutations: Whether to enforce permutation pointers.
|
207 |
+
chunked: Whether to chunk the dataset.
|
208 |
+
chunk_length: Unroll length of chunks, if `chunked` is True.
|
209 |
+
sampler_kwargs: Extra args passed to the sampler.
|
210 |
+
Returns:
|
211 |
+
A sampler (iterator), the number of samples in the iterator (negative
|
212 |
+
if infinite samples), and the spec.
|
213 |
+
"""
|
214 |
+
if length < 0: # load from file
|
215 |
+
dataset_folder = _maybe_download_dataset(FLAGS.dataset_path)
|
216 |
+
sampler, num_samples, spec = clrs.create_dataset(folder=dataset_folder,
|
217 |
+
algorithm=algorithm,
|
218 |
+
batch_size=batch_size,
|
219 |
+
split=split)
|
220 |
+
sampler = sampler.as_numpy_iterator()
|
221 |
+
else:
|
222 |
+
num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
|
223 |
+
sampler, spec = clrs.build_sampler(
|
224 |
+
algorithm,
|
225 |
+
seed=rng.randint(2**32),
|
226 |
+
num_samples=num_samples,
|
227 |
+
length=length,
|
228 |
+
**sampler_kwargs,
|
229 |
+
)
|
230 |
+
sampler = _iterate_sampler(sampler, batch_size)
|
231 |
+
|
232 |
+
if randomize_pos:
|
233 |
+
sampler = clrs.process_random_pos(sampler, rng)
|
234 |
+
if enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS:
|
235 |
+
spec, sampler = clrs.process_pred_as_input(spec, sampler)
|
236 |
+
spec, sampler = clrs.process_permutations(spec, sampler, enforce_permutations)
|
237 |
+
if chunked:
|
238 |
+
sampler = clrs.chunkify(sampler, chunk_length)
|
239 |
+
return sampler, num_samples, spec
|
240 |
+
|
241 |
+
|
242 |
+
def make_multi_sampler(sizes, rng, **kwargs):
|
243 |
+
"""Create a sampler with cycling sample sizes."""
|
244 |
+
ss = []
|
245 |
+
tot_samples = 0
|
246 |
+
for length in sizes:
|
247 |
+
sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
|
248 |
+
ss.append(sampler)
|
249 |
+
tot_samples += num_samples
|
250 |
+
|
251 |
+
def cycle_samplers():
|
252 |
+
while True:
|
253 |
+
for s in ss:
|
254 |
+
yield next(s)
|
255 |
+
return cycle_samplers(), tot_samples, spec
|
256 |
+
|
257 |
+
|
258 |
+
def _concat(dps, axis):
|
259 |
+
return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)
|
260 |
+
|
261 |
+
|
262 |
+
def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
|
263 |
+
"""Collect batches of output and hint preds and evaluate them."""
|
264 |
+
processed_samples = 0
|
265 |
+
preds = []
|
266 |
+
outputs = []
|
267 |
+
while processed_samples < sample_count:
|
268 |
+
feedback = next(sampler)
|
269 |
+
batch_size = feedback.outputs[0].data.shape[0]
|
270 |
+
outputs.append(feedback.outputs)
|
271 |
+
new_rng_key, rng_key = jax.random.split(rng_key)
|
272 |
+
cur_preds, _ = predict_fn(new_rng_key, feedback.features)
|
273 |
+
preds.append(cur_preds)
|
274 |
+
processed_samples += batch_size
|
275 |
+
outputs = _concat(outputs, axis=0)
|
276 |
+
preds = _concat(preds, axis=0)
|
277 |
+
out = clrs.evaluate(outputs, preds)
|
278 |
+
if extras:
|
279 |
+
out.update(extras)
|
280 |
+
return {k: unpack(v) for k, v in out.items()}
|
281 |
+
|
282 |
+
|
283 |
+
def create_samplers(rng, train_lengths: List[int]):
|
284 |
+
"""Create all the samplers."""
|
285 |
+
train_samplers = []
|
286 |
+
val_samplers = []
|
287 |
+
val_sample_counts = []
|
288 |
+
test_samplers = []
|
289 |
+
test_sample_counts = []
|
290 |
+
spec_list = []
|
291 |
+
|
292 |
+
for algo_idx, algorithm in enumerate(FLAGS.algorithms):
|
293 |
+
# Make full dataset pipeline run on CPU (including prefetching).
|
294 |
+
with tf.device('/cpu:0'):
|
295 |
+
|
296 |
+
if algorithm in ['naive_string_matcher', 'kmp_matcher']:
|
297 |
+
# Fixed haystack + needle; variability will be in needle
|
298 |
+
# Still, for chunked training, we maintain as many samplers
|
299 |
+
# as train lengths, since, for each length there is a separate state,
|
300 |
+
# and we must keep the 1:1 relationship between states and samplers.
|
301 |
+
max_length = max(train_lengths)
|
302 |
+
if max_length > 0: # if < 0, we are using the benchmark data
|
303 |
+
max_length = (max_length * 5) // 4
|
304 |
+
train_lengths = [max_length]
|
305 |
+
if FLAGS.chunked_training:
|
306 |
+
train_lengths = train_lengths * len(train_lengths)
|
307 |
+
|
308 |
+
logging.info('Creating samplers for algo %s', algorithm)
|
309 |
+
|
310 |
+
p = tuple([0.1 + 0.1 * i for i in range(9)])
|
311 |
+
if p and algorithm in ['articulation_points', 'bridges',
|
312 |
+
'mst_kruskal', 'bipartite_matching']:
|
313 |
+
# Choose a lower connection probability for the above algorithms,
|
314 |
+
# otherwise trajectories are very long
|
315 |
+
p = tuple(np.array(p) / 2)
|
316 |
+
length_needle = FLAGS.length_needle
|
317 |
+
sampler_kwargs = dict(p=p, length_needle=length_needle)
|
318 |
+
if length_needle == 0:
|
319 |
+
sampler_kwargs.pop('length_needle')
|
320 |
+
|
321 |
+
common_sampler_args = dict(
|
322 |
+
algorithm=FLAGS.algorithms[algo_idx],
|
323 |
+
rng=rng,
|
324 |
+
enforce_pred_as_input=FLAGS.enforce_pred_as_input,
|
325 |
+
enforce_permutations=FLAGS.enforce_permutations,
|
326 |
+
chunk_length=FLAGS.chunk_length,
|
327 |
+
)
|
328 |
+
|
329 |
+
train_args = dict(sizes=train_lengths,
|
330 |
+
split='train',
|
331 |
+
batch_size=FLAGS.batch_size,
|
332 |
+
multiplier=-1,
|
333 |
+
randomize_pos=FLAGS.random_pos,
|
334 |
+
chunked=FLAGS.chunked_training,
|
335 |
+
sampler_kwargs=sampler_kwargs,
|
336 |
+
**common_sampler_args)
|
337 |
+
train_sampler, _, spec = make_multi_sampler(**train_args)
|
338 |
+
|
339 |
+
mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
|
340 |
+
val_args = dict(sizes=[np.amax(train_lengths)],
|
341 |
+
split='val',
|
342 |
+
batch_size=32,
|
343 |
+
multiplier=2 * mult,
|
344 |
+
randomize_pos=FLAGS.random_pos,
|
345 |
+
chunked=False,
|
346 |
+
sampler_kwargs=sampler_kwargs,
|
347 |
+
**common_sampler_args)
|
348 |
+
val_sampler, val_samples, spec = make_multi_sampler(**val_args)
|
349 |
+
|
350 |
+
test_args = dict(sizes=[-1],
|
351 |
+
split='test',
|
352 |
+
batch_size=32,
|
353 |
+
multiplier=2 * mult,
|
354 |
+
randomize_pos=False,
|
355 |
+
chunked=False,
|
356 |
+
sampler_kwargs={},
|
357 |
+
**common_sampler_args)
|
358 |
+
test_sampler, test_samples, spec = make_multi_sampler(**test_args)
|
359 |
+
|
360 |
+
spec_list.append(spec)
|
361 |
+
train_samplers.append(train_sampler)
|
362 |
+
val_samplers.append(val_sampler)
|
363 |
+
val_sample_counts.append(val_samples)
|
364 |
+
test_samplers.append(test_sampler)
|
365 |
+
test_sample_counts.append(test_samples)
|
366 |
+
|
367 |
+
return (train_samplers,
|
368 |
+
val_samplers, val_sample_counts,
|
369 |
+
test_samplers, test_sample_counts,
|
370 |
+
spec_list)
|
371 |
+
|
372 |
+
|
373 |
+
def main(unused_argv):
|
374 |
+
if FLAGS.hint_mode == 'encoded_decoded':
|
375 |
+
encode_hints = True
|
376 |
+
decode_hints = True
|
377 |
+
elif FLAGS.hint_mode == 'decoded_only':
|
378 |
+
encode_hints = False
|
379 |
+
decode_hints = True
|
380 |
+
elif FLAGS.hint_mode == 'none':
|
381 |
+
encode_hints = False
|
382 |
+
decode_hints = False
|
383 |
+
else:
|
384 |
+
raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')
|
385 |
+
|
386 |
+
train_lengths = [int(x) for x in FLAGS.train_lengths]
|
387 |
+
|
388 |
+
rng = np.random.RandomState(FLAGS.seed)
|
389 |
+
rng_key = jax.random.PRNGKey(rng.randint(2**32))
|
390 |
+
|
391 |
+
# Create samplers
|
392 |
+
(train_samplers,
|
393 |
+
val_samplers, val_sample_counts,
|
394 |
+
test_samplers, test_sample_counts,
|
395 |
+
spec_list) = create_samplers(rng, train_lengths)
|
396 |
+
|
397 |
+
processor_factory = clrs.get_processor_factory(
|
398 |
+
FLAGS.processor_type,
|
399 |
+
use_ln=FLAGS.use_ln,
|
400 |
+
nb_triplet_fts=FLAGS.nb_triplet_fts,
|
401 |
+
nb_heads=FLAGS.nb_heads
|
402 |
+
)
|
403 |
+
model_params = dict(
|
404 |
+
processor_factory=processor_factory,
|
405 |
+
hidden_dim=FLAGS.hidden_size,
|
406 |
+
encode_hints=encode_hints,
|
407 |
+
decode_hints=decode_hints,
|
408 |
+
encoder_init=FLAGS.encoder_init,
|
409 |
+
use_lstm=FLAGS.use_lstm,
|
410 |
+
learning_rate=FLAGS.learning_rate,
|
411 |
+
grad_clip_max_norm=FLAGS.grad_clip_max_norm,
|
412 |
+
checkpoint_path=FLAGS.checkpoint_path,
|
413 |
+
freeze_processor=FLAGS.freeze_processor,
|
414 |
+
dropout_prob=FLAGS.dropout_prob,
|
415 |
+
hint_teacher_forcing=FLAGS.hint_teacher_forcing,
|
416 |
+
hint_repred_mode=FLAGS.hint_repred_mode,
|
417 |
+
nb_msg_passing_steps=FLAGS.nb_msg_passing_steps,
|
418 |
+
)
|
419 |
+
|
420 |
+
# save spec_list and model_params; do not change or delete!!
|
421 |
+
if not os.path.exists(FLAGS.checkpoint_path):
|
422 |
+
os.makedirs(FLAGS.checkpoint_path)
|
423 |
+
|
424 |
+
with open(os.path.join(FLAGS.checkpoint_path, 'spec_list.pkl'), 'wb') as f:
|
425 |
+
pickle.dump(spec_list, f)
|
426 |
+
model_params_save = copy.deepcopy(model_params)
|
427 |
+
model_params_save["processor_factory"] = (FLAGS.processor_type, FLAGS.use_ln, FLAGS.nb_triplet_fts, FLAGS.nb_heads)
|
428 |
+
with open(os.path.join(FLAGS.checkpoint_path, 'model_params.pkl'), 'wb') as f:
|
429 |
+
pickle.dump(model_params_save, f)
|
430 |
+
|
431 |
+
eval_model = BaselineModel(
|
432 |
+
spec=spec_list,
|
433 |
+
dummy_trajectory=[next(t) for t in val_samplers],
|
434 |
+
**model_params
|
435 |
+
)
|
436 |
+
if FLAGS.chunked_training:
|
437 |
+
train_model = BaselineModelChunked(
|
438 |
+
spec=spec_list,
|
439 |
+
dummy_trajectory=[next(t) for t in train_samplers],
|
440 |
+
**model_params
|
441 |
+
)
|
442 |
+
else:
|
443 |
+
train_model = eval_model
|
444 |
+
|
445 |
+
# Training loop.
|
446 |
+
best_score = -1.0
|
447 |
+
current_train_items = [0] * len(FLAGS.algorithms)
|
448 |
+
step = 0
|
449 |
+
next_eval = 0
|
450 |
+
# Make sure scores improve on first step, but not overcome best score
|
451 |
+
# until all algos have had at least one evaluation.
|
452 |
+
val_scores = [-99999.9] * len(FLAGS.algorithms)
|
453 |
+
length_idx = 0
|
454 |
+
|
455 |
+
while step < FLAGS.train_steps:
|
456 |
+
feedback_list = [next(t) for t in train_samplers]
|
457 |
+
|
458 |
+
# Initialize model.
|
459 |
+
if step == 0:
|
460 |
+
all_features = [f.features for f in feedback_list]
|
461 |
+
if FLAGS.chunked_training:
|
462 |
+
# We need to initialize the model with samples of all lengths for
|
463 |
+
# all algorithms. Also, we need to make sure that the order of these
|
464 |
+
# sample sizes is the same as the order of the actual training sizes.
|
465 |
+
all_length_features = [all_features] + [
|
466 |
+
[next(t).features for t in train_samplers]
|
467 |
+
for _ in range(len(train_lengths))]
|
468 |
+
train_model.init(all_length_features[:-1], FLAGS.seed + 1)
|
469 |
+
else:
|
470 |
+
train_model.init(all_features, FLAGS.seed + 1)
|
471 |
+
|
472 |
+
# Training step.
|
473 |
+
# enable logging now that we have initialized the model
|
474 |
+
logging.set_verbosity(logging.INFO)
|
475 |
+
for algo_idx in range(len(train_samplers)):
|
476 |
+
feedback = feedback_list[algo_idx]
|
477 |
+
rng_key, new_rng_key = jax.random.split(rng_key)
|
478 |
+
if FLAGS.chunked_training:
|
479 |
+
# In chunked training, we must indicate which training length we are
|
480 |
+
# using, so the model uses the correct state.
|
481 |
+
length_and_algo_idx = (length_idx, algo_idx)
|
482 |
+
else:
|
483 |
+
# In non-chunked training, all training lengths can be treated equally,
|
484 |
+
# since there is no state to maintain between batches.
|
485 |
+
length_and_algo_idx = algo_idx
|
486 |
+
cur_loss = train_model.feedback(rng_key, feedback, length_and_algo_idx)
|
487 |
+
rng_key = new_rng_key
|
488 |
+
|
489 |
+
if FLAGS.chunked_training:
|
490 |
+
examples_in_chunk = np.sum(feedback.features.is_last).item()
|
491 |
+
else:
|
492 |
+
examples_in_chunk = len(feedback.features.lengths)
|
493 |
+
current_train_items[algo_idx] += examples_in_chunk
|
494 |
+
if step % FLAGS.log_every == 0:
|
495 |
+
logging.info('Algo %s step %i current loss %f, current_train_items %i.',
|
496 |
+
FLAGS.algorithms[algo_idx], step,
|
497 |
+
cur_loss, current_train_items[algo_idx])
|
498 |
+
|
499 |
+
# Periodically evaluate model
|
500 |
+
if step >= next_eval:
|
501 |
+
eval_model.params = train_model.params
|
502 |
+
for algo_idx in range(len(train_samplers)):
|
503 |
+
common_extras = {'examples_seen': current_train_items[algo_idx],
|
504 |
+
'step': step,
|
505 |
+
'algorithm': FLAGS.algorithms[algo_idx]}
|
506 |
+
|
507 |
+
# Validation info.
|
508 |
+
new_rng_key, rng_key = jax.random.split(rng_key)
|
509 |
+
val_stats = collect_and_eval(
|
510 |
+
val_samplers[algo_idx],
|
511 |
+
functools.partial(eval_model.predict, algorithm_index=algo_idx),
|
512 |
+
val_sample_counts[algo_idx],
|
513 |
+
new_rng_key,
|
514 |
+
extras=common_extras)
|
515 |
+
logging.info('(val) algo %s step %d: %s',
|
516 |
+
FLAGS.algorithms[algo_idx], step, val_stats)
|
517 |
+
val_scores[algo_idx] = val_stats['score']
|
518 |
+
|
519 |
+
next_eval += FLAGS.eval_every
|
520 |
+
|
521 |
+
# If best total score, update best checkpoint.
|
522 |
+
# Also save a best checkpoint on the first step.
|
523 |
+
msg = (f'best avg val score was '
|
524 |
+
f'{best_score/len(FLAGS.algorithms):.3f}, '
|
525 |
+
f'current avg val score is {np.mean(val_scores):.3f}, '
|
526 |
+
f'val scores are: ')
|
527 |
+
msg += ', '.join(
|
528 |
+
['%s: %.3f' % (x, y) for (x, y) in zip(FLAGS.algorithms, val_scores)])
|
529 |
+
if (sum(val_scores) > best_score) or step == 0:
|
530 |
+
best_score = sum(val_scores)
|
531 |
+
logging.info('Checkpointing best model, %s', msg)
|
532 |
+
train_model.save_model('best.pkl')
|
533 |
+
else:
|
534 |
+
logging.info('Not saving new best model, %s', msg)
|
535 |
+
|
536 |
+
step += 1
|
537 |
+
length_idx = (length_idx + 1) % len(train_lengths)
|
538 |
+
|
539 |
+
logging.info('Restoring best model from checkpoint...')
|
540 |
+
eval_model.restore_model('best.pkl', only_load_processor=False)
|
541 |
+
|
542 |
+
for algo_idx in range(len(train_samplers)):
|
543 |
+
common_extras = {'examples_seen': current_train_items[algo_idx],
|
544 |
+
'step': step,
|
545 |
+
'algorithm': FLAGS.algorithms[algo_idx]}
|
546 |
+
|
547 |
+
new_rng_key, rng_key = jax.random.split(rng_key)
|
548 |
+
test_stats = collect_and_eval(
|
549 |
+
test_samplers[algo_idx],
|
550 |
+
functools.partial(eval_model.predict, algorithm_index=algo_idx),
|
551 |
+
test_sample_counts[algo_idx],
|
552 |
+
new_rng_key,
|
553 |
+
extras=common_extras)
|
554 |
+
logging.info('(test) algo %s : %s', FLAGS.algorithms[algo_idx], test_stats)
|
555 |
+
|
556 |
+
logging.info('Done!')
|
557 |
+
|
558 |
+
|
559 |
+
if __name__ == '__main__':
|
560 |
+
app.run(main)
|
benchmarks/CLRS/scripts/eval.py
ADDED
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2022 DeepMind Technologies Limited. All Rights Reserved.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
# ==============================================================================
|
15 |
+
|
16 |
+
"""Run training of one or more algorithmic tasks from CLRS."""
|
17 |
+
|
18 |
+
import os
|
19 |
+
# disable logging until training starts
|
20 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
21 |
+
|
22 |
+
import functools
|
23 |
+
import os
|
24 |
+
import shutil
|
25 |
+
from typing import Any, Dict, List
|
26 |
+
|
27 |
+
from absl import app
|
28 |
+
from absl import flags
|
29 |
+
from absl import logging
|
30 |
+
# disable logging until training starts
|
31 |
+
logging.set_verbosity(logging.ERROR)
|
32 |
+
|
33 |
+
import clrs
|
34 |
+
import jax
|
35 |
+
import numpy as np
|
36 |
+
import requests
|
37 |
+
import tensorflow as tf
|
38 |
+
import sys
|
39 |
+
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../env")))
|
40 |
+
from baselines import BaselineModel, BaselineModelChunked
|
41 |
+
import pickle
|
42 |
+
|
43 |
+
flags.DEFINE_list('algorithms', ['floyd_warshall'], 'Which algorithms to run.')
|
44 |
+
flags.DEFINE_list('train_lengths', ['4', '7', '11', '13', '16'],
|
45 |
+
'Which training sizes to use. A size of -1 means '
|
46 |
+
'use the benchmark dataset.')
|
47 |
+
flags.DEFINE_integer('length_needle', -8,
|
48 |
+
'Length of needle for training and validation '
|
49 |
+
'(not testing) in string matching algorithms. '
|
50 |
+
'A negative value randomizes the length for each sample '
|
51 |
+
'between 1 and the opposite of the value. '
|
52 |
+
'A value of 0 means use always 1/4 of the length of '
|
53 |
+
'the haystack (the default sampler behavior).')
|
54 |
+
flags.DEFINE_integer('seed', 42, 'Random seed to set')
|
55 |
+
|
56 |
+
flags.DEFINE_boolean('random_pos', True,
|
57 |
+
'Randomize the pos input common to all algos.')
|
58 |
+
flags.DEFINE_boolean('enforce_permutations', True,
|
59 |
+
'Whether to enforce permutation-type node pointers.')
|
60 |
+
flags.DEFINE_boolean('enforce_pred_as_input', True,
|
61 |
+
'Whether to change pred_h hints into pred inputs.')
|
62 |
+
flags.DEFINE_integer('batch_size', 32, 'Batch size used for training.')
|
63 |
+
flags.DEFINE_boolean('chunked_training', False,
|
64 |
+
'Whether to use chunking for training.')
|
65 |
+
flags.DEFINE_integer('chunk_length', 16,
|
66 |
+
'Time chunk length used for training (if '
|
67 |
+
'`chunked_training` is True.')
|
68 |
+
flags.DEFINE_integer('train_steps', 1000, 'Number of training iterations.')
|
69 |
+
flags.DEFINE_integer('eval_every', 50, 'Evaluation frequency (in steps).')
|
70 |
+
flags.DEFINE_integer('test_every', 500, 'Evaluation frequency (in steps).')
|
71 |
+
flags.DEFINE_integer('log_every', 50, 'Logging frequency (in steps).')
|
72 |
+
|
73 |
+
flags.DEFINE_integer('hidden_size', 128,
|
74 |
+
'Number of hidden units of the model.')
|
75 |
+
flags.DEFINE_integer('nb_heads', 1, 'Number of heads for GAT processors')
|
76 |
+
flags.DEFINE_integer('nb_msg_passing_steps', 1,
|
77 |
+
'Number of message passing steps to run per hint.')
|
78 |
+
flags.DEFINE_float('learning_rate', 0.001, 'Learning rate to use.')
|
79 |
+
flags.DEFINE_float('grad_clip_max_norm', 1.0,
|
80 |
+
'Gradient clipping by norm. 0.0 disables grad clipping')
|
81 |
+
flags.DEFINE_float('dropout_prob', 0.0, 'Dropout rate to use.')
|
82 |
+
flags.DEFINE_float('hint_teacher_forcing', 0.0,
|
83 |
+
'Probability that ground-truth teacher hints are encoded '
|
84 |
+
'during training instead of predicted hints. Only '
|
85 |
+
'pertinent in encoded_decoded modes.')
|
86 |
+
flags.DEFINE_enum('hint_mode', 'encoded_decoded',
|
87 |
+
['encoded_decoded', 'decoded_only', 'none'],
|
88 |
+
'How should hints be used? Note, each mode defines a '
|
89 |
+
'separate task, with various difficulties. `encoded_decoded` '
|
90 |
+
'requires the model to explicitly materialise hint sequences '
|
91 |
+
'and therefore is hardest, but also most aligned to the '
|
92 |
+
'underlying algorithmic rule. Hence, `encoded_decoded` '
|
93 |
+
'should be treated as the default mode for our benchmark. '
|
94 |
+
'In `decoded_only`, hints are only used for defining '
|
95 |
+
'reconstruction losses. Often, this will perform well, but '
|
96 |
+
'note that we currently do not make any efforts to '
|
97 |
+
'counterbalance the various hint losses. Hence, for certain '
|
98 |
+
'tasks, the best performance will now be achievable with no '
|
99 |
+
'hint usage at all (`none`).')
|
100 |
+
flags.DEFINE_enum('hint_repred_mode', 'soft', ['soft', 'hard', 'hard_on_eval'],
|
101 |
+
'How to process predicted hints when fed back as inputs.'
|
102 |
+
'In soft mode, we use softmaxes for categoricals, pointers '
|
103 |
+
'and mask_one, and sigmoids for masks. '
|
104 |
+
'In hard mode, we use argmax instead of softmax, and hard '
|
105 |
+
'thresholding of masks. '
|
106 |
+
'In hard_on_eval mode, soft mode is '
|
107 |
+
'used for training and hard mode is used for evaluation.')
|
108 |
+
flags.DEFINE_boolean('use_ln', True,
|
109 |
+
'Whether to use layer normalisation in the processor.')
|
110 |
+
flags.DEFINE_boolean('use_lstm', False,
|
111 |
+
'Whether to insert an LSTM after message passing.')
|
112 |
+
flags.DEFINE_integer('nb_triplet_fts', 8,
|
113 |
+
'How many triplet features to compute?')
|
114 |
+
|
115 |
+
flags.DEFINE_enum('encoder_init', 'xavier_on_scalars',
|
116 |
+
['default', 'xavier_on_scalars'],
|
117 |
+
'Initialiser to use for the encoders.')
|
118 |
+
flags.DEFINE_enum('processor_type', 'triplet_gmpnn',
|
119 |
+
['deepsets', 'mpnn', 'pgn', 'pgn_mask',
|
120 |
+
'triplet_mpnn', 'triplet_pgn', 'triplet_pgn_mask',
|
121 |
+
'gat', 'gatv2', 'gat_full', 'gatv2_full',
|
122 |
+
'gpgn', 'gpgn_mask', 'gmpnn',
|
123 |
+
'triplet_gpgn', 'triplet_gpgn_mask', 'triplet_gmpnn'],
|
124 |
+
'Processor type to use as the network P.')
|
125 |
+
|
126 |
+
flags.DEFINE_string('checkpoint_path', '../env/checkpoints',
|
127 |
+
'Path in which checkpoints are saved.')
|
128 |
+
flags.DEFINE_string('dataset_path', '/tmp/CLRS30',
|
129 |
+
'Path in which dataset is stored.')
|
130 |
+
flags.DEFINE_boolean('freeze_processor', False,
|
131 |
+
'Whether to freeze the processor of the model.')
|
132 |
+
|
133 |
+
FLAGS = flags.FLAGS
|
134 |
+
|
135 |
+
|
136 |
+
PRED_AS_INPUT_ALGOS = [
|
137 |
+
'binary_search',
|
138 |
+
'minimum',
|
139 |
+
'find_maximum_subarray',
|
140 |
+
'find_maximum_subarray_kadane',
|
141 |
+
'matrix_chain_order',
|
142 |
+
'lcs_length',
|
143 |
+
'optimal_bst',
|
144 |
+
'activity_selector',
|
145 |
+
'task_scheduling',
|
146 |
+
'naive_string_matcher',
|
147 |
+
'kmp_matcher',
|
148 |
+
'jarvis_march']
|
149 |
+
|
150 |
+
|
151 |
+
def unpack(v):
|
152 |
+
try:
|
153 |
+
return v.item() # DeviceArray
|
154 |
+
except (AttributeError, ValueError):
|
155 |
+
return v
|
156 |
+
|
157 |
+
|
158 |
+
def _iterate_sampler(sampler, batch_size):
|
159 |
+
while True:
|
160 |
+
yield sampler.next(batch_size)
|
161 |
+
|
162 |
+
|
163 |
+
def _maybe_download_dataset(dataset_path):
|
164 |
+
"""Download CLRS30 dataset if needed."""
|
165 |
+
dataset_folder = os.path.join(dataset_path, clrs.get_clrs_folder())
|
166 |
+
if os.path.isdir(dataset_folder):
|
167 |
+
logging.info('Dataset found at %s. Skipping download.', dataset_folder)
|
168 |
+
return dataset_folder
|
169 |
+
logging.info('Dataset not found in %s. Downloading...', dataset_folder)
|
170 |
+
|
171 |
+
clrs_url = clrs.get_dataset_gcp_url()
|
172 |
+
request = requests.get(clrs_url, allow_redirects=True)
|
173 |
+
clrs_file = os.path.join(dataset_path, os.path.basename(clrs_url))
|
174 |
+
os.makedirs(dataset_folder)
|
175 |
+
open(clrs_file, 'wb').write(request.content)
|
176 |
+
shutil.unpack_archive(clrs_file, extract_dir=dataset_folder)
|
177 |
+
os.remove(clrs_file)
|
178 |
+
return dataset_folder
|
179 |
+
|
180 |
+
|
181 |
+
def make_sampler(length: int,
|
182 |
+
rng: Any,
|
183 |
+
algorithm: str,
|
184 |
+
split: str,
|
185 |
+
batch_size: int,
|
186 |
+
multiplier: int,
|
187 |
+
randomize_pos: bool,
|
188 |
+
enforce_pred_as_input: bool,
|
189 |
+
enforce_permutations: bool,
|
190 |
+
chunked: bool,
|
191 |
+
chunk_length: int,
|
192 |
+
sampler_kwargs: Dict[str, Any]):
|
193 |
+
"""Create a sampler with given options.
|
194 |
+
|
195 |
+
Args:
|
196 |
+
length: Size of samples (i.e., number of nodes in the graph).
|
197 |
+
A length of -1 will mean that the benchmark
|
198 |
+
dataset (for the given split) is used. Positive sizes will instantiate
|
199 |
+
samplers of the corresponding size.
|
200 |
+
rng: Numpy random state.
|
201 |
+
algorithm: The name of the algorithm to sample from.
|
202 |
+
split: 'train', 'val' or 'test'.
|
203 |
+
batch_size: Samples per batch.
|
204 |
+
multiplier: Integer multiplier for the number of samples in the dataset,
|
205 |
+
only used for positive sizes. Negative multiplier means infinite samples.
|
206 |
+
randomize_pos: Whether to randomize the `pos` input.
|
207 |
+
enforce_pred_as_input: Whether to convert fixed pred_h hints to inputs.
|
208 |
+
enforce_permutations: Whether to enforce permutation pointers.
|
209 |
+
chunked: Whether to chunk the dataset.
|
210 |
+
chunk_length: Unroll length of chunks, if `chunked` is True.
|
211 |
+
sampler_kwargs: Extra args passed to the sampler.
|
212 |
+
Returns:
|
213 |
+
A sampler (iterator), the number of samples in the iterator (negative
|
214 |
+
if infinite samples), and the spec.
|
215 |
+
"""
|
216 |
+
if length < 0: # load from file
|
217 |
+
dataset_folder = _maybe_download_dataset(FLAGS.dataset_path)
|
218 |
+
sampler, num_samples, spec = clrs.create_dataset(folder=dataset_folder,
|
219 |
+
algorithm=algorithm,
|
220 |
+
batch_size=batch_size,
|
221 |
+
split=split)
|
222 |
+
sampler = sampler.as_numpy_iterator()
|
223 |
+
else:
|
224 |
+
num_samples = clrs.CLRS30[split]['num_samples'] * multiplier
|
225 |
+
sampler, spec = clrs.build_sampler(
|
226 |
+
algorithm,
|
227 |
+
seed=rng.randint(2**32),
|
228 |
+
num_samples=num_samples,
|
229 |
+
length=length,
|
230 |
+
**sampler_kwargs,
|
231 |
+
)
|
232 |
+
sampler = _iterate_sampler(sampler, batch_size)
|
233 |
+
|
234 |
+
if randomize_pos:
|
235 |
+
sampler = clrs.process_random_pos(sampler, rng)
|
236 |
+
if enforce_pred_as_input and algorithm in PRED_AS_INPUT_ALGOS:
|
237 |
+
spec, sampler = clrs.process_pred_as_input(spec, sampler)
|
238 |
+
spec, sampler = clrs.process_permutations(spec, sampler, enforce_permutations)
|
239 |
+
if chunked:
|
240 |
+
sampler = clrs.chunkify(sampler, chunk_length)
|
241 |
+
return sampler, num_samples, spec
|
242 |
+
|
243 |
+
|
244 |
+
def make_multi_sampler(sizes, rng, **kwargs):
|
245 |
+
"""Create a sampler with cycling sample sizes."""
|
246 |
+
ss = []
|
247 |
+
tot_samples = 0
|
248 |
+
for length in sizes:
|
249 |
+
sampler, num_samples, spec = make_sampler(length, rng, **kwargs)
|
250 |
+
ss.append(sampler)
|
251 |
+
tot_samples += num_samples
|
252 |
+
|
253 |
+
def cycle_samplers():
|
254 |
+
while True:
|
255 |
+
for s in ss:
|
256 |
+
yield next(s)
|
257 |
+
return cycle_samplers(), tot_samples, spec
|
258 |
+
|
259 |
+
|
260 |
+
def _concat(dps, axis):
|
261 |
+
return jax.tree_util.tree_map(lambda *x: np.concatenate(x, axis), *dps)
|
262 |
+
|
263 |
+
|
264 |
+
def collect_and_eval(sampler, predict_fn, sample_count, rng_key, extras):
|
265 |
+
"""Collect batches of output and hint preds and evaluate them."""
|
266 |
+
processed_samples = 0
|
267 |
+
preds = []
|
268 |
+
outputs = []
|
269 |
+
while processed_samples < sample_count:
|
270 |
+
feedback = next(sampler)
|
271 |
+
batch_size = feedback.outputs[0].data.shape[0]
|
272 |
+
outputs.append(feedback.outputs)
|
273 |
+
new_rng_key, rng_key = jax.random.split(rng_key)
|
274 |
+
cur_preds, _ = predict_fn(new_rng_key, feedback.features)
|
275 |
+
preds.append(cur_preds)
|
276 |
+
processed_samples += batch_size
|
277 |
+
outputs = _concat(outputs, axis=0)
|
278 |
+
preds = _concat(preds, axis=0)
|
279 |
+
out = clrs.evaluate(outputs, preds)
|
280 |
+
if extras:
|
281 |
+
out.update(extras)
|
282 |
+
return {k: unpack(v) for k, v in out.items()}
|
283 |
+
|
284 |
+
|
285 |
+
def create_samplers(rng, train_lengths: List[int]):
|
286 |
+
"""Create all the samplers."""
|
287 |
+
train_samplers = []
|
288 |
+
val_samplers = []
|
289 |
+
val_sample_counts = []
|
290 |
+
test_samplers = []
|
291 |
+
test_sample_counts = []
|
292 |
+
spec_list = []
|
293 |
+
|
294 |
+
for algo_idx, algorithm in enumerate(FLAGS.algorithms):
|
295 |
+
# Make full dataset pipeline run on CPU (including prefetching).
|
296 |
+
with tf.device('/cpu:0'):
|
297 |
+
|
298 |
+
if algorithm in ['naive_string_matcher', 'kmp_matcher']:
|
299 |
+
# Fixed haystack + needle; variability will be in needle
|
300 |
+
# Still, for chunked training, we maintain as many samplers
|
301 |
+
# as train lengths, since, for each length there is a separate state,
|
302 |
+
# and we must keep the 1:1 relationship between states and samplers.
|
303 |
+
max_length = max(train_lengths)
|
304 |
+
if max_length > 0: # if < 0, we are using the benchmark data
|
305 |
+
max_length = (max_length * 5) // 4
|
306 |
+
train_lengths = [max_length]
|
307 |
+
if FLAGS.chunked_training:
|
308 |
+
train_lengths = train_lengths * len(train_lengths)
|
309 |
+
|
310 |
+
logging.info('Creating samplers for algo %s', algorithm)
|
311 |
+
|
312 |
+
p = tuple([0.1 + 0.1 * i for i in range(9)])
|
313 |
+
if p and algorithm in ['articulation_points', 'bridges',
|
314 |
+
'mst_kruskal', 'bipartite_matching']:
|
315 |
+
# Choose a lower connection probability for the above algorithms,
|
316 |
+
# otherwise trajectories are very long
|
317 |
+
p = tuple(np.array(p) / 2)
|
318 |
+
length_needle = FLAGS.length_needle
|
319 |
+
sampler_kwargs = dict(p=p, length_needle=length_needle)
|
320 |
+
if length_needle == 0:
|
321 |
+
sampler_kwargs.pop('length_needle')
|
322 |
+
|
323 |
+
common_sampler_args = dict(
|
324 |
+
algorithm=FLAGS.algorithms[algo_idx],
|
325 |
+
rng=rng,
|
326 |
+
enforce_pred_as_input=FLAGS.enforce_pred_as_input,
|
327 |
+
enforce_permutations=FLAGS.enforce_permutations,
|
328 |
+
chunk_length=FLAGS.chunk_length,
|
329 |
+
)
|
330 |
+
|
331 |
+
train_args = dict(sizes=train_lengths,
|
332 |
+
split='train',
|
333 |
+
batch_size=FLAGS.batch_size,
|
334 |
+
multiplier=-1,
|
335 |
+
randomize_pos=FLAGS.random_pos,
|
336 |
+
chunked=FLAGS.chunked_training,
|
337 |
+
sampler_kwargs=sampler_kwargs,
|
338 |
+
**common_sampler_args)
|
339 |
+
train_sampler, _, spec = make_multi_sampler(**train_args)
|
340 |
+
|
341 |
+
mult = clrs.CLRS_30_ALGS_SETTINGS[algorithm]['num_samples_multiplier']
|
342 |
+
val_args = dict(sizes=[np.amax(train_lengths)],
|
343 |
+
split='val',
|
344 |
+
batch_size=32,
|
345 |
+
multiplier=2 * mult,
|
346 |
+
randomize_pos=FLAGS.random_pos,
|
347 |
+
chunked=False,
|
348 |
+
sampler_kwargs=sampler_kwargs,
|
349 |
+
**common_sampler_args)
|
350 |
+
val_sampler, val_samples, spec = make_multi_sampler(**val_args)
|
351 |
+
|
352 |
+
test_args = dict(sizes=[-1],
|
353 |
+
split='test',
|
354 |
+
batch_size=32,
|
355 |
+
multiplier=2 * mult,
|
356 |
+
randomize_pos=False,
|
357 |
+
chunked=False,
|
358 |
+
sampler_kwargs={},
|
359 |
+
**common_sampler_args)
|
360 |
+
test_sampler, test_samples, spec = make_multi_sampler(**test_args)
|
361 |
+
|
362 |
+
spec_list.append(spec)
|
363 |
+
train_samplers.append(train_sampler)
|
364 |
+
val_samplers.append(val_sampler)
|
365 |
+
val_sample_counts.append(val_samples)
|
366 |
+
test_samplers.append(test_sampler)
|
367 |
+
test_sample_counts.append(test_samples)
|
368 |
+
|
369 |
+
return (train_samplers,
|
370 |
+
val_samplers, val_sample_counts,
|
371 |
+
test_samplers, test_sample_counts,
|
372 |
+
spec_list)
|
373 |
+
|
374 |
+
|
375 |
+
def get_score(submission_folder):
|
376 |
+
FLAGS(["eval.py"])
|
377 |
+
if FLAGS.hint_mode == 'encoded_decoded':
|
378 |
+
encode_hints = True
|
379 |
+
decode_hints = True
|
380 |
+
elif FLAGS.hint_mode == 'decoded_only':
|
381 |
+
encode_hints = False
|
382 |
+
decode_hints = True
|
383 |
+
elif FLAGS.hint_mode == 'none':
|
384 |
+
encode_hints = False
|
385 |
+
decode_hints = False
|
386 |
+
else:
|
387 |
+
raise ValueError('Hint mode not in {encoded_decoded, decoded_only, none}.')
|
388 |
+
|
389 |
+
train_lengths = [int(x) for x in FLAGS.train_lengths]
|
390 |
+
|
391 |
+
rng = np.random.RandomState(FLAGS.seed)
|
392 |
+
rng_key = jax.random.PRNGKey(rng.randint(2**32))
|
393 |
+
|
394 |
+
checkpoint_path = os.path.join(submission_folder, 'checkpoints')
|
395 |
+
|
396 |
+
spec_list = pickle.load(open(os.path.join(checkpoint_path, 'spec_list.pkl'), 'rb'))
|
397 |
+
|
398 |
+
# Create samplers
|
399 |
+
(train_samplers,
|
400 |
+
val_samplers, val_sample_counts,
|
401 |
+
test_samplers, test_sample_counts,
|
402 |
+
spec_list) = create_samplers(rng, train_lengths)
|
403 |
+
|
404 |
+
# load spec_list
|
405 |
+
model_params = pickle.load(open(os.path.join(checkpoint_path, 'model_params.pkl'), 'rb'))
|
406 |
+
processor_type, use_ln, nb_triplet_fts, nb_heads = model_params["processor_factory"]
|
407 |
+
model_params["processor_factory"] = clrs.get_processor_factory(
|
408 |
+
processor_type,
|
409 |
+
use_ln=use_ln,
|
410 |
+
nb_triplet_fts=nb_triplet_fts,
|
411 |
+
nb_heads=nb_heads
|
412 |
+
)
|
413 |
+
model_params["checkpoint_path"]=checkpoint_path
|
414 |
+
|
415 |
+
eval_model = BaselineModel(
|
416 |
+
spec=spec_list,
|
417 |
+
dummy_trajectory=[next(t) for t in val_samplers],
|
418 |
+
**model_params
|
419 |
+
)
|
420 |
+
|
421 |
+
feedback_list = [next(t) for t in train_samplers]
|
422 |
+
|
423 |
+
# Initialize model.
|
424 |
+
all_features = [f.features for f in feedback_list]
|
425 |
+
eval_model.init(all_features, FLAGS.seed + 1)
|
426 |
+
|
427 |
+
|
428 |
+
logging.set_verbosity(logging.INFO)
|
429 |
+
|
430 |
+
logging.info('Restoring best model from checkpoint...')
|
431 |
+
eval_model.restore_model('best.pkl', only_load_processor=False)
|
432 |
+
|
433 |
+
for algo_idx in range(len(train_samplers)):
|
434 |
+
new_rng_key, rng_key = jax.random.split(rng_key)
|
435 |
+
val_stats = collect_and_eval(
|
436 |
+
val_samplers[algo_idx],
|
437 |
+
functools.partial(eval_model.predict, algorithm_index=algo_idx),
|
438 |
+
val_sample_counts[algo_idx],
|
439 |
+
new_rng_key,
|
440 |
+
extras = {})
|
441 |
+
# logging.info('(val) algo %s: %s', FLAGS.algorithms[algo_idx], val_stats)
|
442 |
+
|
443 |
+
new_rng_key, rng_key = jax.random.split(rng_key)
|
444 |
+
test_stats = collect_and_eval(
|
445 |
+
test_samplers[algo_idx],
|
446 |
+
functools.partial(eval_model.predict, algorithm_index=algo_idx),
|
447 |
+
test_sample_counts[algo_idx],
|
448 |
+
new_rng_key,
|
449 |
+
extras = {})
|
450 |
+
# logging.info('(test) algo %s : %s', FLAGS.algorithms[algo_idx], test_stats)
|
451 |
+
return test_stats['score']
|
452 |
+
|
453 |
+
if __name__ == '__main__':
|
454 |
+
app.run(get_score)
|
benchmarks/CLRS/scripts/requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
absl-py>=0.13.0
|
2 |
+
attrs>=21.4.0
|
3 |
+
chex>=0.0.8
|
4 |
+
dm-haiku>=0.0.4
|
5 |
+
jax>=0.2.18
|
6 |
+
jaxlib>=0.1.69
|
7 |
+
numpy>=1.21.1
|
8 |
+
opt-einsum>=3.3.0
|
9 |
+
optax>=0.0.9
|
10 |
+
six>=1.16.0
|
11 |
+
tensorflow>=2.9.0
|
12 |
+
tfds-nightly==4.5.2.dev202204190046
|
13 |
+
toolz>=0.11.1
|
benchmarks/CLRS/scripts/research_problem.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
Improve the baseline model performance on the task floyd_warshall in The CLRS Algorithmic Reasoning Benchmark. The dataset description is available in data_description.txt, and the baseline model architecture description is available in baseline_model_description.txt. To run the baseline model, execute train.py. Note that the core message passing function of the baseline model is implemented in function get_triplet_msgs (L301 in processors.py). You can modify this function to improve the baseline model performance. You can also modify other parts of the baseline model and training script to improve its performance, as long as the final model is still loadable by calling BaselineModel class as in L415 in train.py.
|
2 |
+
|
3 |
+
When you submit your final answer, you will be evaluated on the performance of the checkpoint checkpoints/best.pkl saved by train.py. Note that the final model must still be loadable by calling BaselineModel class as in L415 in train.py and with the saved spec_list.pkl and model_params.pkl.
|
benchmarks/CLRS/scripts/source_code.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
https://github.com/deepmind/clrs/blob/master/clrs/examples/run.py
|
benchmarks/amp-parkinsons-disease-progression-prediction/env/data_description.txt
ADDED
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Dataset Description
|
2 |
+
The goal of this competition is to predict the course of Parkinson's disease (PD) using protein abundance data. The complete set of proteins involved in PD remains an open research question and any proteins that have predictive value are likely worth investigating further. The core of the dataset consists of protein abundance values derived from mass spectrometry readings of cerebrospinal fluid (CSF) samples gathered from several hundred patients. Each patient contributed several samples over the course of multiple years while they also took assessments of PD severity.
|
3 |
+
|
4 |
+
This is a time-series code competition: you will receive test set data and make predictions with a time-series API. See the evaluation_details.txt for details.
|
5 |
+
|
6 |
+
Files
|
7 |
+
train_peptides.csv Mass spectrometry data at the peptide level. Peptides are the component subunits of proteins.
|
8 |
+
|
9 |
+
visit_id - ID code for the visit.
|
10 |
+
visit_month - The month of the visit, relative to the first visit by the patient.
|
11 |
+
patient_id - An ID code for the patient.
|
12 |
+
UniProt - The UniProt ID code for the associated protein. There are often several peptides per protein.
|
13 |
+
Peptide - The sequence of amino acids included in the peptide. See this table for the relevant codes. Some rare annotations may not be included in the table. The test set may include peptides not found in the train set.
|
14 |
+
PeptideAbundance - The frequency of the amino acid in the sample.
|
15 |
+
train_proteins.csv Protein expression frequencies aggregated from the peptide level data.
|
16 |
+
|
17 |
+
visit_id - ID code for the visit.
|
18 |
+
visit_month - The month of the visit, relative to the first visit by the patient.
|
19 |
+
patient_id - An ID code for the patient.
|
20 |
+
UniProt - The UniProt ID code for the associated protein. There are often several peptides per protein. The test set may include proteins not found in the train set.
|
21 |
+
NPX - Normalized protein expression. The frequency of the protein's occurrence in the sample. May not have a 1:1 relationship with the component peptides as some proteins contain repeated copies of a given peptide.
|
22 |
+
train_clinical_data.csv
|
23 |
+
|
24 |
+
visit_id - ID code for the visit.
|
25 |
+
visit_month - The month of the visit, relative to the first visit by the patient.
|
26 |
+
patient_id - An ID code for the patient.
|
27 |
+
updrs_[1-4] - The patient's score for part N of the Unified Parkinson's Disease Rating Scale. Higher numbers indicate more severe symptoms. Each sub-section covers a distinct category of symptoms, such as mood and behavior for Part 1 and motor functions for Part 3.
|
28 |
+
upd23b_clinical_state_on_medication - Whether or not the patient was taking medication such as Levodopa during the UPDRS assessment. Expected to mainly affect the scores for Part 3 (motor function). These medications wear off fairly quickly (on the order of one day) so it's common for patients to take the motor function exam twice in a single month, both with and without medication.
|
29 |
+
supplemental_clinical_data.csv Clinical records without any associated CSF samples. This data is intended to provide additional context about the typical progression of Parkinsons. Uses the same columns as train_clinical_data.csv.
|
30 |
+
|
31 |
+
example_test_files/ Data intended to illustrate how the API functions. Includes the same columns delivered by the API (ie no updrs columns).
|
32 |
+
|
33 |
+
public_timeseries_testing_util.py A file for running custom API tests.
|
benchmarks/amp-parkinsons-disease-progression-prediction/env/evaluation_details.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Submissions are evaluated on SMAPE between forecasts and actual values. We define SMAPE = 0 when the actual and predicted values are both 0.
|
2 |
+
|
3 |
+
For each patient visit where a protein/peptide sample was taken you will need to estimate both their UPDRS scores for that visit and predict their scores for any potential visits 6, 12, and 24 months later. Predictions for any visits that didn't ultimately take place are ignored.
|
4 |
+
|
5 |
+
You must submit to this competition using the provided python time-series API, which ensures that models do not peek forward in time. To use the API, follow this template in Kaggle Notebooks:
|
6 |
+
|
7 |
+
from public_timeseries_testing_util import MockApi
|
8 |
+
env = MockApi.make_env() # initialize the environment
|
9 |
+
iter_test = env.iter_test() # an iterator which loops over the test files
|
10 |
+
for (test, test_peptides, test_proteins, sample_submission) in iter_test:
|
11 |
+
sample_prediction_df['rating'] = np.arange(len(sample_prediction)) # make your predictions here
|
12 |
+
env.predict(sample_prediction_df) # register your predictions
|
benchmarks/amp-parkinsons-disease-progression-prediction/env/public_timeseries_testing_util.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
An unlocked version of the timeseries API intended for testing alternate inputs.
|
3 |
+
Mirrors the production timeseries API in the crucial respects, but won't be as fast.
|
4 |
+
|
5 |
+
ONLY works afer the first three variables in MockAPI.__init__ are populated.
|
6 |
+
'''
|
7 |
+
|
8 |
+
from typing import Sequence, Tuple
|
9 |
+
|
10 |
+
import pandas as pd
|
11 |
+
|
12 |
+
|
13 |
+
class MockApi:
|
14 |
+
def __init__(self):
|
15 |
+
'''
|
16 |
+
YOU MUST UPDATE THE FIRST THREE LINES of this method.
|
17 |
+
They've been intentionally left in an invalid state.
|
18 |
+
|
19 |
+
Variables to set:
|
20 |
+
input_paths: a list of two or more paths to the csv files to be served
|
21 |
+
group_id_column: the column that identifies which groups of rows the API should serve.
|
22 |
+
A call to iter_test serves all rows of all dataframes with the current group ID value.
|
23 |
+
export_group_id_column: if true, the dataframes iter_test serves will include the group_id_column values.
|
24 |
+
'''
|
25 |
+
self.input_paths: Sequence[str] = [
|
26 |
+
'example_test_files/test.csv',
|
27 |
+
'example_test_files/test_peptides.csv',
|
28 |
+
'example_test_files/test_proteins.csv',
|
29 |
+
'example_test_files/sample_submission.csv',
|
30 |
+
]
|
31 |
+
self.group_id_column: str = 'visit_month'
|
32 |
+
self.export_group_id_column: bool = True
|
33 |
+
# iter_test is only designed to support at least two dataframes, such as test and sample_submission
|
34 |
+
assert len(self.input_paths) >= 2
|
35 |
+
|
36 |
+
self._status = 'initialized'
|
37 |
+
self.predictions = []
|
38 |
+
|
39 |
+
def iter_test(self) -> Tuple[pd.DataFrame]:
|
40 |
+
'''
|
41 |
+
Loads all of the dataframes specified in self.input_paths,
|
42 |
+
then yields all rows in those dataframes that equal the current self.group_id_column value.
|
43 |
+
'''
|
44 |
+
if self._status != 'initialized':
|
45 |
+
|
46 |
+
raise Exception('WARNING: the real API can only iterate over `iter_test()` once.')
|
47 |
+
|
48 |
+
dataframes = []
|
49 |
+
for pth in self.input_paths:
|
50 |
+
dataframes.append(pd.read_csv(pth, low_memory=False))
|
51 |
+
group_order = dataframes[0][self.group_id_column].drop_duplicates().tolist()
|
52 |
+
dataframes = [df.set_index(self.group_id_column) for df in dataframes]
|
53 |
+
|
54 |
+
for group_id in group_order:
|
55 |
+
self._status = 'prediction_needed'
|
56 |
+
current_data = []
|
57 |
+
for df in dataframes:
|
58 |
+
try:
|
59 |
+
cur_df = df.loc[group_id].copy()
|
60 |
+
# returning single line dataframes from df.loc requires special handling
|
61 |
+
if not isinstance(cur_df, pd.DataFrame):
|
62 |
+
cur_df = pd.DataFrame({a: b for a, b in zip(cur_df.index.values, cur_df.values)}, index=[group_id])
|
63 |
+
cur_df = cur_df.index.rename(self.group_id_column)
|
64 |
+
except KeyError:
|
65 |
+
cur_df = df.loc[[]].copy()
|
66 |
+
cur_df = cur_df.reset_index(drop=not(self.export_group_id_column))
|
67 |
+
current_data.append(cur_df)
|
68 |
+
yield tuple(current_data)
|
69 |
+
|
70 |
+
while self._status != 'prediction_received':
|
71 |
+
print('You must call `predict()` successfully before you can continue with `iter_test()`', flush=True)
|
72 |
+
yield None
|
73 |
+
|
74 |
+
with open('submission.csv', 'w') as f_open:
|
75 |
+
pd.concat(self.predictions).to_csv(f_open, index=False)
|
76 |
+
self._status = 'finished'
|
77 |
+
|
78 |
+
def predict(self, user_predictions: pd.DataFrame):
|
79 |
+
'''
|
80 |
+
Accepts and stores the user's predictions and unlocks iter_test once that is done
|
81 |
+
'''
|
82 |
+
if self._status == 'finished':
|
83 |
+
raise Exception('You have already made predictions for the full test set.')
|
84 |
+
if self._status != 'prediction_needed':
|
85 |
+
raise Exception('You must get the next test sample from `iter_test()` first.')
|
86 |
+
if not isinstance(user_predictions, pd.DataFrame):
|
87 |
+
raise Exception('You must provide a DataFrame.')
|
88 |
+
|
89 |
+
self.predictions.append(user_predictions)
|
90 |
+
self._status = 'prediction_received'
|
91 |
+
|
92 |
+
|
93 |
+
def make_env():
|
94 |
+
return MockApi()
|
benchmarks/amp-parkinsons-disease-progression-prediction/env/train.py
ADDED
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
from sklearn.preprocessing import StandardScaler
|
4 |
+
from sklearn.ensemble import RandomForestRegressor
|
5 |
+
from public_timeseries_testing_util import MockApi
|
6 |
+
from sklearn.metrics import make_scorer
|
7 |
+
from sklearn.model_selection import KFold, GroupKFold, cross_val_score
|
8 |
+
from sklearn.utils import check_consistent_length
|
9 |
+
|
10 |
+
# Define the metric
|
11 |
+
def smapep1(y_true, y_pred):
|
12 |
+
"""SMAPE of y+1, a nonnegative float, smaller is better
|
13 |
+
|
14 |
+
Parameters: y_true, y_pred: array-like
|
15 |
+
|
16 |
+
Returns 100 for 100 % error.
|
17 |
+
y_true may have missing values.
|
18 |
+
"""
|
19 |
+
check_consistent_length(y_true, y_pred)
|
20 |
+
y_true = np.array(y_true, copy=False).ravel()
|
21 |
+
y_pred = np.array(y_pred, copy=False).ravel()
|
22 |
+
y_true, y_pred = y_true[np.isfinite(y_true)], y_pred[np.isfinite(y_true)]
|
23 |
+
if (y_true < 0).any(): raise ValueError('y_true < 0')
|
24 |
+
if (y_pred < 0).any(): raise ValueError('y_pred < 0')
|
25 |
+
denominator = (y_true + y_pred) / 2 + 1
|
26 |
+
ape = np.abs(y_pred - y_true) / denominator
|
27 |
+
return np.average(ape) * 100
|
28 |
+
|
29 |
+
# The scorer returns nonpositive values so that greater is better.
|
30 |
+
# It will be used as an argument to cross_val_score
|
31 |
+
smapep1_scorer = make_scorer(smapep1, greater_is_better=False)
|
32 |
+
|
33 |
+
def get_predictions(my_train, model):
|
34 |
+
|
35 |
+
# Forecast
|
36 |
+
my_train = my_train.fillna(0)
|
37 |
+
result = pd.DataFrame(columns = ['prediction_id', 'rating'])
|
38 |
+
final = []
|
39 |
+
target = ["updrs_1", "updrs_2", "updrs_3", "updrs_4"]
|
40 |
+
|
41 |
+
for u in target:
|
42 |
+
|
43 |
+
# Predict
|
44 |
+
X = my_train["visit_month"]
|
45 |
+
|
46 |
+
predict = model[u].predict(X.values.reshape(-1, 1)).tolist()
|
47 |
+
complete_result = my_train[["visit_id",'visit_month']].values.tolist()
|
48 |
+
for index in range(len(complete_result)):
|
49 |
+
complete_result[index].extend(predict[index])
|
50 |
+
temp = pd.DataFrame(complete_result,
|
51 |
+
columns = ["visit_id",'visit_month',u +'_plus_0_months',
|
52 |
+
u +'_plus_6_months',
|
53 |
+
u +'_plus_12_months',
|
54 |
+
u +'_plus_24_months'])
|
55 |
+
temp = temp.melt( id_vars=["visit_id",'visit_month'],
|
56 |
+
value_vars=[ u +'_plus_0_months' , u +'_plus_6_months',
|
57 |
+
u +'_plus_12_months',u +"_plus_24_months"],
|
58 |
+
value_name = 'rating')
|
59 |
+
temp['prediction_id'] = temp['visit_id'] + '_' + temp['variable']
|
60 |
+
|
61 |
+
final.append(temp[['prediction_id','rating']])
|
62 |
+
final = pd.concat(final)
|
63 |
+
final = final.drop_duplicates(subset=['prediction_id', 'rating'])
|
64 |
+
return final
|
65 |
+
|
66 |
+
if __name__ == "__main__":
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
target = ["updrs_1", "updrs_2", "updrs_3", "updrs_4"]
|
71 |
+
data_proteins = pd.read_csv('train_proteins.csv')
|
72 |
+
data_clinical = pd.read_csv('train_clinical_data.csv')
|
73 |
+
data_peptides = pd.read_csv('train_peptides.csv')
|
74 |
+
data_supplemental = pd.read_csv('supplemental_clinical_data.csv')
|
75 |
+
merged_data = pd.concat([data_clinical, data_supplemental])
|
76 |
+
|
77 |
+
## TODO: data cleaning and feature engineering
|
78 |
+
# Right now, we only use the month data and the target data
|
79 |
+
id_list = merged_data['patient_id'].unique().tolist()
|
80 |
+
data_for_train = {}
|
81 |
+
for u in target:
|
82 |
+
final = []
|
83 |
+
for id_ in id_list:
|
84 |
+
infor_of_id = merged_data[merged_data['patient_id'] == id_]
|
85 |
+
month_per_id = infor_of_id.visit_month.tolist()
|
86 |
+
for month in month_per_id:
|
87 |
+
check = [month, id_]
|
88 |
+
for plus in [0,6,12,24]:
|
89 |
+
if month + plus in month_per_id :
|
90 |
+
month_value = infor_of_id[infor_of_id.visit_month == month+plus][u].values[0]
|
91 |
+
if month_value != np.nan:
|
92 |
+
check.append(month_value)
|
93 |
+
if len(check) == 6:
|
94 |
+
final.append(check)
|
95 |
+
check = pd.DataFrame(final,columns = ['month', 'patient_id',u+'+0',u+'+6',u+'+12',u+'+24'])
|
96 |
+
data_for_train[u] = check.dropna()
|
97 |
+
|
98 |
+
|
99 |
+
## train model
|
100 |
+
model = {}
|
101 |
+
overall_score = []
|
102 |
+
target = ["updrs_1", "updrs_2", "updrs_3", "updrs_4"]
|
103 |
+
|
104 |
+
for i, u in enumerate(target):
|
105 |
+
|
106 |
+
# Train data
|
107 |
+
X = data_for_train[u]['month']
|
108 |
+
y = data_for_train[u].iloc[:,2:6]
|
109 |
+
trained = RandomForestRegressor().fit(X.values.reshape(-1, 1), y)
|
110 |
+
# Save model
|
111 |
+
model[u] = trained
|
112 |
+
|
113 |
+
## cross validation and print results
|
114 |
+
print('Cross-validation scores')
|
115 |
+
|
116 |
+
cvs = cross_val_score(RandomForestRegressor(),
|
117 |
+
X=X.values.reshape(-1, 1), y=y,
|
118 |
+
groups=data_for_train[u]['patient_id'],
|
119 |
+
scoring=smapep1_scorer,
|
120 |
+
cv=GroupKFold(n_splits=8),
|
121 |
+
error_score='raise')
|
122 |
+
print([f'updrs_{i}:'], -cvs.round(1), -cvs.mean().round(1))
|
123 |
+
overall_score.append(-cvs)
|
124 |
+
print(f'Overall cv score of the group model: {np.array(overall_score).mean():.2f}')
|
125 |
+
|
126 |
+
|
127 |
+
|
128 |
+
## save to submission.csv file for the test set by using this following API call
|
129 |
+
|
130 |
+
env = MockApi()
|
131 |
+
iter_test = env.iter_test() # an iterator which loops over the test files
|
132 |
+
|
133 |
+
# The API will deliver four dataframes in this specific order:
|
134 |
+
for iteration, (test_clinical_data, test_peptides, test_proteins, sample_submission) in enumerate(iter_test):
|
135 |
+
# TODO - make your predictions here by modifying 'rating' sample_submission dataframe
|
136 |
+
pred = get_predictions(test_clinical_data, model).round(0)
|
137 |
+
|
138 |
+
for index in sample_submission['prediction_id']:
|
139 |
+
sample_submission.loc[sample_submission['prediction_id']==index, 'rating'] = pred[pred['prediction_id']==index]['rating'].values
|
140 |
+
|
141 |
+
env.predict(sample_submission) # register your predictions
|
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/eval.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import sys
|
3 |
+
import os
|
4 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../env"))
|
5 |
+
from importlib import reload
|
6 |
+
import train
|
7 |
+
reload(train)
|
8 |
+
import pandas as pd
|
9 |
+
from train import smapep1, check_consistent_length
|
10 |
+
|
11 |
+
|
12 |
+
def get_score(submission_folder = "../env"):
|
13 |
+
submission_path = os.path.join(submission_folder, "submission.csv")
|
14 |
+
solution = pd.read_csv(os.path.join(os.path.dirname(__file__), "answer.csv"))
|
15 |
+
submission = pd.read_csv(submission_path)
|
16 |
+
|
17 |
+
s = smapep1(solution["rating"], submission["rating"])
|
18 |
+
return s
|
19 |
+
|
20 |
+
if __name__ == "__main__":
|
21 |
+
print(get_score())
|
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/prepare.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import pandas as pd
|
3 |
+
import random
|
4 |
+
import os
|
5 |
+
|
6 |
+
taskname = "amp-parkinsons-disease-progression-prediction"
|
7 |
+
download_dir = "../env"
|
8 |
+
|
9 |
+
input(f"Consent to the competition at https://www.kaggle.com/competitions/{taskname}/data; Press any key after you have accepted the rules online.")
|
10 |
+
|
11 |
+
subprocess.run(["kaggle", "competitions", "download", "-c", taskname], cwd=download_dir)
|
12 |
+
subprocess.run(["unzip", "-n", f"{taskname}.zip"], cwd=download_dir)
|
13 |
+
subprocess.run(["rm", f"{taskname}.zip"], cwd=download_dir)
|
14 |
+
subprocess.run(["rm", "-r", "amp_pd_peptide"], cwd=download_dir)
|
15 |
+
subprocess.run(["rm", "-r", "amp_pd_peptide_310"], cwd=download_dir)
|
16 |
+
|
17 |
+
# ## split train to train and test in env
|
18 |
+
|
19 |
+
data_proteins = pd.read_csv(f'{download_dir}/train_proteins.csv')
|
20 |
+
data_clinical = pd.read_csv(f'{download_dir}/train_clinical_data.csv')
|
21 |
+
data_peptides = pd.read_csv(f'{download_dir}/train_peptides.csv')
|
22 |
+
data_supplemental = pd.read_csv(f'{download_dir}/supplemental_clinical_data.csv')
|
23 |
+
|
24 |
+
random.seed(42)
|
25 |
+
patient_id = data_clinical['patient_id'].unique()
|
26 |
+
test_patient_id = random.sample(patient_id.tolist(), 2)
|
27 |
+
train_patient_id = [x for x in patient_id if x not in test_patient_id]
|
28 |
+
|
29 |
+
data_proteins[data_proteins['patient_id'].isin(train_patient_id)].to_csv(f'{download_dir}/train_proteins.csv', index=False)
|
30 |
+
data_clinical[data_clinical['patient_id'].isin(train_patient_id)].to_csv(f'{download_dir}/train_clinical_data.csv', index=False)
|
31 |
+
data_peptides[data_peptides['patient_id'].isin(train_patient_id)].to_csv(f'{download_dir}/train_peptides.csv', index=False)
|
32 |
+
data_supplemental[data_supplemental['patient_id'].isin(train_patient_id)].to_csv(f'{download_dir}/supplemental_clinical_data.csv', index=False)
|
33 |
+
|
34 |
+
data_proteins[data_proteins['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/example_test_files/test_proteins.csv', index=False)
|
35 |
+
data_peptides[data_peptides['patient_id'].isin(test_patient_id)].to_csv(f'{download_dir}/example_test_files/test_peptides.csv', index=False)
|
36 |
+
test_clinical = data_clinical[data_clinical['patient_id'].isin(test_patient_id)]
|
37 |
+
|
38 |
+
|
39 |
+
# Create test.csv
|
40 |
+
temp_list = []
|
41 |
+
for i in range(1, 5):
|
42 |
+
temp = test_clinical.copy()
|
43 |
+
temp['level_3'] = i
|
44 |
+
temp['updrs_test'] = f'updrs_{i}'
|
45 |
+
temp_list.append(temp)
|
46 |
+
mock_train = pd.concat(temp_list)
|
47 |
+
mock_train['row_id'] = (mock_train[['patient_id', 'visit_month', 'level_3']]
|
48 |
+
.apply((lambda r: f"{r.patient_id}_{int(r.visit_month)}_updrs_{r.level_3}"), axis=1))
|
49 |
+
mock_train[['visit_id', 'patient_id', 'visit_month','row_id', 'updrs_test']].to_csv(f'{download_dir}/example_test_files/test.csv', index=False)
|
50 |
+
|
51 |
+
# Create sample_submission.csv
|
52 |
+
temp_list = []
|
53 |
+
for wait in [0, 6, 12, 24]:
|
54 |
+
temp = mock_train.copy()
|
55 |
+
temp['wait'] = wait
|
56 |
+
temp_list.append(temp)
|
57 |
+
y = pd.concat(temp_list)
|
58 |
+
y = y[y.visit_month + y.wait <= 108]
|
59 |
+
y['prediction_id'] = (y[['patient_id', 'visit_month', 'wait', 'level_3']]
|
60 |
+
.apply((lambda r: f"{r.patient_id}_{int(r.visit_month)}_updrs_{r.level_3}_plus_{r.wait}_months"), axis=1))
|
61 |
+
|
62 |
+
def get_rating(row):
|
63 |
+
rating = test_clinical[test_clinical["visit_id"] == f'{row.patient_id}_{int(row.visit_month) + int(row.wait) }' ][f'updrs_{row.level_3}']
|
64 |
+
if len(rating) == 0:
|
65 |
+
return None
|
66 |
+
return rating.item()
|
67 |
+
|
68 |
+
y['rating'] = (y[['patient_id', 'visit_month', 'wait', 'level_3']].apply(get_rating, axis=1))
|
69 |
+
y = y.dropna()
|
70 |
+
y[['prediction_id', 'rating', 'visit_month']].to_csv(f'answer.csv', index=False)
|
71 |
+
|
72 |
+
y['rating'] = 0
|
73 |
+
y[['prediction_id', 'rating', 'visit_month']].to_csv(f'{download_dir}/example_test_files/sample_submission.csv', index=False)
|
74 |
+
|
75 |
+
|
76 |
+
|
77 |
+
|
78 |
+
|
79 |
+
|
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/read_only_files.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
example_test_files/*
|
2 |
+
./supplemental_clinical_data.csv
|
3 |
+
./train_clinical_data.csv
|
4 |
+
./train_peptide.csv
|
5 |
+
./train_protein.csv
|
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/research_problem.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
Go through the data_description.txt file to understand the data and the machine learning task. You can summarize it in your research logs to keep track of what all you have to do.
|
2 |
+
Then fill in the provided train.py script to train a model and iterate over different models or feature selections to get a better performance (for SMAPE score the lower is better). Finally, you should submit the predictions of your best model for the test set as a submission.csv as described in the evaluation_details.txt file.
|
3 |
+
Never try to read any csv files directly. Do not forget to execute the changes you made to check for performance.
|
benchmarks/amp-parkinsons-disease-progression-prediction/scripts/source_code.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
https://www.kaggle.com/code/dangkhanhle/test-model
|
2 |
+
https://www.kaggle.com/code/ambrosm/pdpp-linear-and-isotonic-groups/notebook
|
benchmarks/babylm/env/babyLM_for_hf.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import datasets
|
3 |
+
|
4 |
+
_CITATION = """
|
5 |
+
"""
|
6 |
+
|
7 |
+
_DESCRIPTION = """\
|
8 |
+
BabyLM data
|
9 |
+
"""
|
10 |
+
_HOMEPAGE = "https://babylm.github.io/"
|
11 |
+
_LICENSE = "????"
|
12 |
+
_DATA_URL = "./babylm_data"
|
13 |
+
|
14 |
+
|
15 |
+
class babyLMConfig(datasets.BuilderConfig):
|
16 |
+
"""BuilderConfig for babyLM."""
|
17 |
+
|
18 |
+
def __init__(self, data_url, **kwargs):
|
19 |
+
"""BuilderConfig for babyLM
|
20 |
+
Args:
|
21 |
+
data_url: `string`, url to the dataset (word or raw level)
|
22 |
+
**kwargs: keyword arguments forwarded to super.
|
23 |
+
"""
|
24 |
+
super().__init__(
|
25 |
+
version=datasets.Version(
|
26 |
+
"1.0.0",
|
27 |
+
),
|
28 |
+
**kwargs,
|
29 |
+
)
|
30 |
+
self.data_url = data_url
|
31 |
+
|
32 |
+
|
33 |
+
class babyLM(datasets.GeneratorBasedBuilder):
|
34 |
+
"""TODO: Short description of dataset dataset."""
|
35 |
+
DATA_SOURCES = [
|
36 |
+
'aochildes', 'bnc_spoken', 'cbt', 'children_stories',
|
37 |
+
'gutenberg', 'open_subtitles', 'qed', 'simple_wikipedia',
|
38 |
+
'switchboard', 'wikipedia']
|
39 |
+
VERSION = datasets.Version("0.0.0")
|
40 |
+
BUILDER_CONFIGS = [
|
41 |
+
babyLMConfig(
|
42 |
+
name="babyLM-10M",
|
43 |
+
data_url=os.path.join(_DATA_URL, 'babylm_10M'),
|
44 |
+
description="Raw level dataset: the raw tokens before the addition of <unk> tokens. 10M tokens.",
|
45 |
+
),
|
46 |
+
babyLMConfig(
|
47 |
+
name="babyLM-100M",
|
48 |
+
data_url=os.path.join(_DATA_URL, 'babylm_100M'),
|
49 |
+
description="Raw level dataset: the raw tokens before the addition of <unk> tokens. 100M tokens.",
|
50 |
+
),
|
51 |
+
]
|
52 |
+
|
53 |
+
def _info(self):
|
54 |
+
return datasets.DatasetInfo(
|
55 |
+
# This is the description that will appear on the datasets page.
|
56 |
+
description=_DESCRIPTION,
|
57 |
+
# datasets.features.FeatureConnectors
|
58 |
+
features=datasets.Features(
|
59 |
+
{
|
60 |
+
"text": datasets.Value("string")
|
61 |
+
# These are the features of your dataset like images, labels ...
|
62 |
+
}
|
63 |
+
),
|
64 |
+
# If there's a common (input, target) tuple from the features,
|
65 |
+
# specify them here. They'll be used if as_supervised=True in
|
66 |
+
# builder.as_dataset.
|
67 |
+
supervised_keys=None,
|
68 |
+
homepage=_HOMEPAGE,
|
69 |
+
license=_LICENSE,
|
70 |
+
citation=_CITATION,
|
71 |
+
)
|
72 |
+
|
73 |
+
def _split_generators(self, dl_manager):
|
74 |
+
"""Returns SplitGenerators."""
|
75 |
+
ret_list = [
|
76 |
+
datasets.SplitGenerator(
|
77 |
+
name=datasets.Split.TEST,
|
78 |
+
gen_kwargs={"data_folder": os.path.join(_DATA_URL, "babylm_test"), "split": "test"},
|
79 |
+
),
|
80 |
+
datasets.SplitGenerator(
|
81 |
+
name=datasets.Split.VALIDATION,
|
82 |
+
gen_kwargs={"data_folder": os.path.join(_DATA_URL, "babylm_dev"), "split": "dev"},
|
83 |
+
),
|
84 |
+
datasets.SplitGenerator(
|
85 |
+
name=datasets.Split.TRAIN,
|
86 |
+
gen_kwargs={"data_folder": self.config.data_url, "split": "train"},
|
87 |
+
),
|
88 |
+
]
|
89 |
+
return ret_list
|
90 |
+
|
91 |
+
def _generate_examples(self, data_folder, split):
|
92 |
+
"""Yields examples."""
|
93 |
+
all_data_files = [
|
94 |
+
os.path.join(data_folder, f'{source}.{split}')
|
95 |
+
for source in self.DATA_SOURCES]
|
96 |
+
all_lines = []
|
97 |
+
for data_file in all_data_files:
|
98 |
+
with open(data_file, encoding="utf-8") as f:
|
99 |
+
all_lines.extend(f.readlines())
|
100 |
+
for idx, row in enumerate(all_lines):
|
101 |
+
if row.strip():
|
102 |
+
yield idx, {"text": row}
|
103 |
+
else:
|
104 |
+
yield idx, {"text": ""}
|
benchmarks/babylm/env/train.py
ADDED
@@ -0,0 +1,641 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
18 |
+
|
19 |
+
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
20 |
+
https://huggingface.co/models?filter=text-generation
|
21 |
+
"""
|
22 |
+
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import math
|
26 |
+
import os
|
27 |
+
# disable logging until training starts
|
28 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
29 |
+
import sys
|
30 |
+
from dataclasses import dataclass, field
|
31 |
+
from itertools import chain
|
32 |
+
from typing import Optional
|
33 |
+
|
34 |
+
import datasets
|
35 |
+
import evaluate
|
36 |
+
import torch
|
37 |
+
from datasets import load_dataset
|
38 |
+
|
39 |
+
import transformers
|
40 |
+
from transformers import (
|
41 |
+
CONFIG_MAPPING,
|
42 |
+
MODEL_FOR_CAUSAL_LM_MAPPING,
|
43 |
+
AutoConfig,
|
44 |
+
AutoModelForCausalLM,
|
45 |
+
AutoTokenizer,
|
46 |
+
HfArgumentParser,
|
47 |
+
Trainer,
|
48 |
+
TrainingArguments,
|
49 |
+
default_data_collator,
|
50 |
+
is_torch_tpu_available,
|
51 |
+
set_seed,
|
52 |
+
)
|
53 |
+
from transformers.testing_utils import CaptureLogger
|
54 |
+
from transformers.trainer_utils import get_last_checkpoint
|
55 |
+
from transformers.utils import check_min_version, send_example_telemetry
|
56 |
+
from transformers.utils.versions import require_version
|
57 |
+
|
58 |
+
|
59 |
+
|
60 |
+
logger = logging.getLogger(__name__)
|
61 |
+
|
62 |
+
|
63 |
+
MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
|
64 |
+
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
|
65 |
+
|
66 |
+
|
67 |
+
@dataclass
|
68 |
+
class ModelArguments:
|
69 |
+
"""
|
70 |
+
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
71 |
+
"""
|
72 |
+
|
73 |
+
model_name_or_path: Optional[str] = field(
|
74 |
+
default=None,
|
75 |
+
metadata={
|
76 |
+
"help": (
|
77 |
+
"The model checkpoint for weights initialization.Don't set if you want to train a model from scratch."
|
78 |
+
)
|
79 |
+
},
|
80 |
+
)
|
81 |
+
model_type: Optional[str] = field(
|
82 |
+
default="gpt2",
|
83 |
+
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
|
84 |
+
)
|
85 |
+
config_overrides: Optional[str] = field(
|
86 |
+
default=None,
|
87 |
+
metadata={
|
88 |
+
"help": (
|
89 |
+
"Override some existing default config settings when a model is trained from scratch. Example: "
|
90 |
+
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
91 |
+
)
|
92 |
+
},
|
93 |
+
)
|
94 |
+
config_name: Optional[str] = field(
|
95 |
+
default="gpt2", metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
96 |
+
)
|
97 |
+
tokenizer_name: Optional[str] = field(
|
98 |
+
default="gpt2", metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
|
99 |
+
)
|
100 |
+
cache_dir: Optional[str] = field(
|
101 |
+
default=None,
|
102 |
+
metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
|
103 |
+
)
|
104 |
+
use_fast_tokenizer: bool = field(
|
105 |
+
default=True,
|
106 |
+
metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
|
107 |
+
)
|
108 |
+
model_revision: str = field(
|
109 |
+
default="main",
|
110 |
+
metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
|
111 |
+
)
|
112 |
+
use_auth_token: bool = field(
|
113 |
+
default=False,
|
114 |
+
metadata={
|
115 |
+
"help": (
|
116 |
+
"Will use the token generated when running `huggingface-cli login` (necessary to use this script "
|
117 |
+
"with private models)."
|
118 |
+
)
|
119 |
+
},
|
120 |
+
)
|
121 |
+
torch_dtype: Optional[str] = field(
|
122 |
+
default=None,
|
123 |
+
metadata={
|
124 |
+
"help": (
|
125 |
+
"Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
|
126 |
+
"dtype will be automatically derived from the model's weights."
|
127 |
+
),
|
128 |
+
"choices": ["auto", "bfloat16", "float16", "float32"],
|
129 |
+
},
|
130 |
+
)
|
131 |
+
low_cpu_mem_usage: bool = field(
|
132 |
+
default=False,
|
133 |
+
metadata={
|
134 |
+
"help": (
|
135 |
+
"It is an option to create the model as an empty shell, then only materialize its parameters when the pretrained weights are loaded."
|
136 |
+
"set True will benefit LLM loading time and RAM consumption."
|
137 |
+
)
|
138 |
+
},
|
139 |
+
)
|
140 |
+
|
141 |
+
def __post_init__(self):
|
142 |
+
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
|
143 |
+
raise ValueError(
|
144 |
+
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
|
145 |
+
)
|
146 |
+
|
147 |
+
|
148 |
+
@dataclass
|
149 |
+
class DataTrainingArguments:
|
150 |
+
"""
|
151 |
+
Arguments pertaining to what data we are going to input our model for training and eval.
|
152 |
+
"""
|
153 |
+
|
154 |
+
dataset_name: Optional[str] = field(
|
155 |
+
default="babyLM_for_hf.py", metadata={"help": "The name of the dataset to use (via the datasets library)."}
|
156 |
+
)
|
157 |
+
dataset_config_name: Optional[str] = field(
|
158 |
+
default="babyLM-10M", metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
|
159 |
+
)
|
160 |
+
train_file: Optional[str] = field(default=None, metadata={"help": "The input training data file (a text file)."})
|
161 |
+
validation_file: Optional[str] = field(
|
162 |
+
default=None,
|
163 |
+
metadata={"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."},
|
164 |
+
)
|
165 |
+
max_train_samples: Optional[int] = field(
|
166 |
+
default=None,
|
167 |
+
metadata={
|
168 |
+
"help": (
|
169 |
+
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
170 |
+
"value if set."
|
171 |
+
)
|
172 |
+
},
|
173 |
+
)
|
174 |
+
max_eval_samples: Optional[int] = field(
|
175 |
+
default=200,
|
176 |
+
metadata={
|
177 |
+
"help": (
|
178 |
+
"For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
179 |
+
"value if set."
|
180 |
+
)
|
181 |
+
},
|
182 |
+
)
|
183 |
+
streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
|
184 |
+
block_size: Optional[int] = field(
|
185 |
+
default=None,
|
186 |
+
metadata={
|
187 |
+
"help": (
|
188 |
+
"Optional input sequence length after tokenization. "
|
189 |
+
"The training dataset will be truncated in block of this size for training. "
|
190 |
+
"Default to the model max input length for single sentence inputs (take into account special tokens)."
|
191 |
+
)
|
192 |
+
},
|
193 |
+
)
|
194 |
+
overwrite_cache: bool = field(
|
195 |
+
default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
|
196 |
+
)
|
197 |
+
validation_split_percentage: Optional[int] = field(
|
198 |
+
default=5,
|
199 |
+
metadata={
|
200 |
+
"help": "The percentage of the train set used as validation set in case there's no validation split"
|
201 |
+
},
|
202 |
+
)
|
203 |
+
preprocessing_num_workers: Optional[int] = field(
|
204 |
+
default=None,
|
205 |
+
metadata={"help": "The number of processes to use for the preprocessing."},
|
206 |
+
)
|
207 |
+
keep_linebreaks: bool = field(
|
208 |
+
default=True, metadata={"help": "Whether to keep line breaks when using TXT files or not."}
|
209 |
+
)
|
210 |
+
|
211 |
+
def __post_init__(self):
|
212 |
+
if self.streaming:
|
213 |
+
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")
|
214 |
+
|
215 |
+
if self.dataset_name is None and self.train_file is None and self.validation_file is None:
|
216 |
+
raise ValueError("Need either a dataset name or a training/validation file.")
|
217 |
+
else:
|
218 |
+
if self.train_file is not None:
|
219 |
+
extension = self.train_file.split(".")[-1]
|
220 |
+
assert extension in ["csv", "json", "txt"], "`train_file` should be a csv, a json or a txt file."
|
221 |
+
if self.validation_file is not None:
|
222 |
+
extension = self.validation_file.split(".")[-1]
|
223 |
+
assert extension in ["csv", "json", "txt"], "`validation_file` should be a csv, a json or a txt file."
|
224 |
+
|
225 |
+
|
226 |
+
def main():
|
227 |
+
# See all possible arguments in src/transformers/training_args.py
|
228 |
+
# or by passing the --help flag to this script.
|
229 |
+
# We now keep distinct sets of args, for a cleaner separation of concerns.
|
230 |
+
parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
|
231 |
+
if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
|
232 |
+
# If we pass only one argument to the script and it's the path to a json file,
|
233 |
+
# let's parse it to get our arguments.
|
234 |
+
model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
|
235 |
+
else:
|
236 |
+
if "--output_dir" not in sys.argv:
|
237 |
+
sys.argv.append("--output_dir")
|
238 |
+
sys.argv.append("./output")
|
239 |
+
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
|
240 |
+
|
241 |
+
# by default we do both training and evaluation
|
242 |
+
training_args.do_train = True if not "--do_train" in sys.argv else training_args.do_train
|
243 |
+
training_args.do_eval = True if not "--do_eval" in sys.argv else training_args.do_eval
|
244 |
+
training_args.overwrite_output_dir = True if not "--overwrite_output_dir" in sys.argv else training_args.overwrite_output_dir
|
245 |
+
training_args.report_to = [] if not "--report_to" in sys.argv else training_args.report_to
|
246 |
+
training_args.log_level = "critical" if not "--log_level" in sys.argv else training_args.log_level
|
247 |
+
training_args.num_train_epochs = 1 if not "--num_train_epochs" in sys.argv else training_args.num_train_epochs
|
248 |
+
training_args.evaluation_strategy = "steps" if not "--evaluation_strategy" in sys.argv else training_args.evaluation_strategy
|
249 |
+
training_args.eval_steps = 0.2 if not "--eval_steps" in sys.argv else training_args.eval_steps
|
250 |
+
training_args.per_device_train_batch_size = 16 if not "--per_device_train_batch_size" in sys.argv else training_args.per_device_train_batch_size
|
251 |
+
|
252 |
+
# Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
|
253 |
+
# information sent is the one passed as arguments along with your Python/PyTorch versions.
|
254 |
+
send_example_telemetry("run_clm", model_args, data_args)
|
255 |
+
|
256 |
+
# Setup logging
|
257 |
+
logging.basicConfig(
|
258 |
+
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
259 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
260 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
261 |
+
)
|
262 |
+
|
263 |
+
if training_args.should_log:
|
264 |
+
# The default of training_args.log_level is passive, so we set log level at info here to have that default.
|
265 |
+
transformers.utils.logging.set_verbosity_info()
|
266 |
+
|
267 |
+
log_level = training_args.get_process_log_level()
|
268 |
+
logger.setLevel(log_level)
|
269 |
+
datasets.utils.logging.set_verbosity(log_level)
|
270 |
+
transformers.utils.logging.set_verbosity(log_level)
|
271 |
+
transformers.utils.logging.enable_default_handler()
|
272 |
+
transformers.utils.logging.enable_explicit_format()
|
273 |
+
|
274 |
+
# Log on each process the small summary:
|
275 |
+
logger.warning(
|
276 |
+
f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
|
277 |
+
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
|
278 |
+
)
|
279 |
+
logger.info(f"Training/evaluation parameters {training_args}")
|
280 |
+
|
281 |
+
# Detecting last checkpoint.
|
282 |
+
last_checkpoint = None
|
283 |
+
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
|
284 |
+
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
285 |
+
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
|
286 |
+
raise ValueError(
|
287 |
+
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
|
288 |
+
"Use --overwrite_output_dir to overcome."
|
289 |
+
)
|
290 |
+
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
|
291 |
+
logger.info(
|
292 |
+
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
|
293 |
+
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
|
294 |
+
)
|
295 |
+
|
296 |
+
# Set seed before initializing model.
|
297 |
+
set_seed(training_args.seed)
|
298 |
+
|
299 |
+
# Get the datasets: you can either provide your own CSV/JSON/TXT training and evaluation files (see below)
|
300 |
+
# or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
|
301 |
+
# (the dataset will be downloaded automatically from the datasets Hub).
|
302 |
+
#
|
303 |
+
# For CSV/JSON files, this script will use the column called 'text' or the first column if no column called
|
304 |
+
# 'text' is found. You can easily tweak this behavior (see below).
|
305 |
+
#
|
306 |
+
# In distributed training, the load_dataset function guarantee that only one local process can concurrently
|
307 |
+
# download the dataset.
|
308 |
+
if data_args.dataset_name is not None:
|
309 |
+
# Downloading and loading a dataset from the hub.
|
310 |
+
raw_datasets = load_dataset(
|
311 |
+
data_args.dataset_name,
|
312 |
+
data_args.dataset_config_name,
|
313 |
+
cache_dir=model_args.cache_dir,
|
314 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
315 |
+
streaming=data_args.streaming,
|
316 |
+
)
|
317 |
+
if "validation" not in raw_datasets.keys():
|
318 |
+
raw_datasets["validation"] = load_dataset(
|
319 |
+
data_args.dataset_name,
|
320 |
+
data_args.dataset_config_name,
|
321 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
322 |
+
cache_dir=model_args.cache_dir,
|
323 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
324 |
+
streaming=data_args.streaming,
|
325 |
+
)
|
326 |
+
raw_datasets["train"] = load_dataset(
|
327 |
+
data_args.dataset_name,
|
328 |
+
data_args.dataset_config_name,
|
329 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
330 |
+
cache_dir=model_args.cache_dir,
|
331 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
332 |
+
streaming=data_args.streaming,
|
333 |
+
)
|
334 |
+
else:
|
335 |
+
data_files = {}
|
336 |
+
dataset_args = {}
|
337 |
+
if data_args.train_file is not None:
|
338 |
+
data_files["train"] = data_args.train_file
|
339 |
+
if data_args.validation_file is not None:
|
340 |
+
data_files["validation"] = data_args.validation_file
|
341 |
+
extension = (
|
342 |
+
data_args.train_file.split(".")[-1]
|
343 |
+
if data_args.train_file is not None
|
344 |
+
else data_args.validation_file.split(".")[-1]
|
345 |
+
)
|
346 |
+
if extension == "txt":
|
347 |
+
extension = "text"
|
348 |
+
dataset_args["keep_linebreaks"] = data_args.keep_linebreaks
|
349 |
+
raw_datasets = load_dataset(
|
350 |
+
extension,
|
351 |
+
data_files=data_files,
|
352 |
+
cache_dir=model_args.cache_dir,
|
353 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
354 |
+
**dataset_args,
|
355 |
+
)
|
356 |
+
# If no validation data is there, validation_split_percentage will be used to divide the dataset.
|
357 |
+
if "validation" not in raw_datasets.keys():
|
358 |
+
raw_datasets["validation"] = load_dataset(
|
359 |
+
extension,
|
360 |
+
data_files=data_files,
|
361 |
+
split=f"train[:{data_args.validation_split_percentage}%]",
|
362 |
+
cache_dir=model_args.cache_dir,
|
363 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
364 |
+
**dataset_args,
|
365 |
+
)
|
366 |
+
raw_datasets["train"] = load_dataset(
|
367 |
+
extension,
|
368 |
+
data_files=data_files,
|
369 |
+
split=f"train[{data_args.validation_split_percentage}%:]",
|
370 |
+
cache_dir=model_args.cache_dir,
|
371 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
372 |
+
**dataset_args,
|
373 |
+
)
|
374 |
+
|
375 |
+
# See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
|
376 |
+
# https://huggingface.co/docs/datasets/loading_datasets.html.
|
377 |
+
|
378 |
+
# Load pretrained model and tokenizer
|
379 |
+
#
|
380 |
+
# Distributed training:
|
381 |
+
# The .from_pretrained methods guarantee that only one local process can concurrently
|
382 |
+
# download model & vocab.
|
383 |
+
|
384 |
+
config_kwargs = {
|
385 |
+
"cache_dir": model_args.cache_dir,
|
386 |
+
"revision": model_args.model_revision,
|
387 |
+
"use_auth_token": True if model_args.use_auth_token else None,
|
388 |
+
}
|
389 |
+
if model_args.config_name:
|
390 |
+
config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
|
391 |
+
elif model_args.model_name_or_path:
|
392 |
+
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
|
393 |
+
else:
|
394 |
+
config = CONFIG_MAPPING[model_args.model_type]()
|
395 |
+
logger.warning("You are instantiating a new config instance from scratch.")
|
396 |
+
if model_args.config_overrides is not None:
|
397 |
+
logger.info(f"Overriding config: {model_args.config_overrides}")
|
398 |
+
config.update_from_string(model_args.config_overrides)
|
399 |
+
logger.info(f"New config: {config}")
|
400 |
+
|
401 |
+
tokenizer_kwargs = {
|
402 |
+
"cache_dir": model_args.cache_dir,
|
403 |
+
"use_fast": model_args.use_fast_tokenizer,
|
404 |
+
"revision": model_args.model_revision,
|
405 |
+
"use_auth_token": True if model_args.use_auth_token else None,
|
406 |
+
}
|
407 |
+
if model_args.tokenizer_name:
|
408 |
+
tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
|
409 |
+
elif model_args.model_name_or_path:
|
410 |
+
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)
|
411 |
+
else:
|
412 |
+
raise ValueError(
|
413 |
+
"You are instantiating a new tokenizer from scratch. This is not supported by this script."
|
414 |
+
"You can do it from another script, save it, and load it from here, using --tokenizer_name."
|
415 |
+
)
|
416 |
+
|
417 |
+
if model_args.model_name_or_path:
|
418 |
+
torch_dtype = (
|
419 |
+
model_args.torch_dtype
|
420 |
+
if model_args.torch_dtype in ["auto", None]
|
421 |
+
else getattr(torch, model_args.torch_dtype)
|
422 |
+
)
|
423 |
+
model = AutoModelForCausalLM.from_pretrained(
|
424 |
+
model_args.model_name_or_path,
|
425 |
+
from_tf=bool(".ckpt" in model_args.model_name_or_path),
|
426 |
+
config=config,
|
427 |
+
cache_dir=model_args.cache_dir,
|
428 |
+
revision=model_args.model_revision,
|
429 |
+
use_auth_token=True if model_args.use_auth_token else None,
|
430 |
+
torch_dtype=torch_dtype,
|
431 |
+
low_cpu_mem_usage=model_args.low_cpu_mem_usage,
|
432 |
+
)
|
433 |
+
else:
|
434 |
+
model = AutoModelForCausalLM.from_config(config)
|
435 |
+
n_params = sum({p.data_ptr(): p.numel() for p in model.parameters()}.values())
|
436 |
+
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
|
437 |
+
|
438 |
+
# We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
|
439 |
+
# on a small vocab and want a smaller embedding size, remove this test.
|
440 |
+
embedding_size = model.get_input_embeddings().weight.shape[0]
|
441 |
+
if len(tokenizer) > embedding_size:
|
442 |
+
model.resize_token_embeddings(len(tokenizer))
|
443 |
+
|
444 |
+
# Preprocessing the datasets.
|
445 |
+
# First we tokenize all the texts.
|
446 |
+
if training_args.do_train:
|
447 |
+
column_names = list(raw_datasets["train"].features)
|
448 |
+
else:
|
449 |
+
column_names = list(raw_datasets["validation"].features)
|
450 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
451 |
+
|
452 |
+
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
|
453 |
+
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
|
454 |
+
|
455 |
+
def tokenize_function(examples):
|
456 |
+
with CaptureLogger(tok_logger) as cl:
|
457 |
+
output = tokenizer(examples[text_column_name])
|
458 |
+
# clm input could be much much longer than block_size
|
459 |
+
if "Token indices sequence length is longer than the" in cl.out:
|
460 |
+
tok_logger.warning(
|
461 |
+
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
|
462 |
+
" before being passed to the model."
|
463 |
+
)
|
464 |
+
return output
|
465 |
+
|
466 |
+
with training_args.main_process_first(desc="dataset map tokenization"):
|
467 |
+
if not data_args.streaming:
|
468 |
+
tokenized_datasets = raw_datasets.map(
|
469 |
+
tokenize_function,
|
470 |
+
batched=True,
|
471 |
+
num_proc=data_args.preprocessing_num_workers,
|
472 |
+
remove_columns=column_names,
|
473 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
474 |
+
desc="Running tokenizer on dataset",
|
475 |
+
)
|
476 |
+
else:
|
477 |
+
tokenized_datasets = raw_datasets.map(
|
478 |
+
tokenize_function,
|
479 |
+
batched=True,
|
480 |
+
remove_columns=column_names,
|
481 |
+
)
|
482 |
+
|
483 |
+
if data_args.block_size is None:
|
484 |
+
block_size = tokenizer.model_max_length
|
485 |
+
if block_size > 1024:
|
486 |
+
logger.warning(
|
487 |
+
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
|
488 |
+
" of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
|
489 |
+
" override this default with `--block_size xxx`."
|
490 |
+
)
|
491 |
+
block_size = 1024
|
492 |
+
else:
|
493 |
+
if data_args.block_size > tokenizer.model_max_length:
|
494 |
+
logger.warning(
|
495 |
+
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
|
496 |
+
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
497 |
+
)
|
498 |
+
block_size = min(data_args.block_size, tokenizer.model_max_length)
|
499 |
+
|
500 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
501 |
+
def group_texts(examples):
|
502 |
+
# Concatenate all texts.
|
503 |
+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
504 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
505 |
+
# We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
|
506 |
+
# We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
|
507 |
+
total_length = (total_length // block_size) * block_size
|
508 |
+
# Split by chunks of max_len.
|
509 |
+
result = {
|
510 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
511 |
+
for k, t in concatenated_examples.items()
|
512 |
+
}
|
513 |
+
result["labels"] = result["input_ids"].copy()
|
514 |
+
return result
|
515 |
+
|
516 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
517 |
+
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
518 |
+
# to preprocess.
|
519 |
+
#
|
520 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
521 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
522 |
+
|
523 |
+
with training_args.main_process_first(desc="grouping texts together"):
|
524 |
+
if not data_args.streaming:
|
525 |
+
lm_datasets = tokenized_datasets.map(
|
526 |
+
group_texts,
|
527 |
+
batched=True,
|
528 |
+
num_proc=data_args.preprocessing_num_workers,
|
529 |
+
load_from_cache_file=not data_args.overwrite_cache,
|
530 |
+
desc=f"Grouping texts in chunks of {block_size}",
|
531 |
+
)
|
532 |
+
else:
|
533 |
+
lm_datasets = tokenized_datasets.map(
|
534 |
+
group_texts,
|
535 |
+
batched=True,
|
536 |
+
)
|
537 |
+
|
538 |
+
if training_args.do_train:
|
539 |
+
if "train" not in tokenized_datasets:
|
540 |
+
raise ValueError("--do_train requires a train dataset")
|
541 |
+
train_dataset = lm_datasets["train"]
|
542 |
+
if data_args.max_train_samples is not None:
|
543 |
+
max_train_samples = min(len(train_dataset), data_args.max_train_samples)
|
544 |
+
train_dataset = train_dataset.select(range(max_train_samples))
|
545 |
+
|
546 |
+
if training_args.do_eval:
|
547 |
+
if "validation" not in tokenized_datasets:
|
548 |
+
raise ValueError("--do_eval requires a validation dataset")
|
549 |
+
eval_dataset = lm_datasets["validation"]
|
550 |
+
if data_args.max_eval_samples is not None:
|
551 |
+
max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
|
552 |
+
eval_dataset = eval_dataset.select(range(max_eval_samples))
|
553 |
+
|
554 |
+
def preprocess_logits_for_metrics(logits, labels):
|
555 |
+
if isinstance(logits, tuple):
|
556 |
+
# Depending on the model and config, logits may contain extra tensors,
|
557 |
+
# like past_key_values, but logits always come first
|
558 |
+
logits = logits[0]
|
559 |
+
return logits.argmax(dim=-1)
|
560 |
+
|
561 |
+
metric = evaluate.load("accuracy")
|
562 |
+
|
563 |
+
def compute_metrics(eval_preds):
|
564 |
+
preds, labels = eval_preds
|
565 |
+
# preds have the same shape as the labels, after the argmax(-1) has been calculated
|
566 |
+
# by preprocess_logits_for_metrics but we need to shift the labels
|
567 |
+
labels = labels[:, 1:].reshape(-1)
|
568 |
+
preds = preds[:, :-1].reshape(-1)
|
569 |
+
return metric.compute(predictions=preds, references=labels)
|
570 |
+
|
571 |
+
# Initialize our Trainer
|
572 |
+
trainer = Trainer(
|
573 |
+
model=model,
|
574 |
+
args=training_args,
|
575 |
+
train_dataset=train_dataset if training_args.do_train else None,
|
576 |
+
eval_dataset=eval_dataset if training_args.do_eval else None,
|
577 |
+
tokenizer=tokenizer,
|
578 |
+
# Data collator will default to DataCollatorWithPadding, so we change it.
|
579 |
+
data_collator=default_data_collator,
|
580 |
+
compute_metrics=compute_metrics if training_args.do_eval and not is_torch_tpu_available() else None,
|
581 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics
|
582 |
+
if training_args.do_eval and not is_torch_tpu_available()
|
583 |
+
else None,
|
584 |
+
)
|
585 |
+
|
586 |
+
transformers.utils.logging.set_verbosity(transformers.utils.logging.WARNING)
|
587 |
+
|
588 |
+
# Training
|
589 |
+
if training_args.do_train:
|
590 |
+
checkpoint = None
|
591 |
+
if training_args.resume_from_checkpoint is not None:
|
592 |
+
checkpoint = training_args.resume_from_checkpoint
|
593 |
+
elif last_checkpoint is not None:
|
594 |
+
checkpoint = last_checkpoint
|
595 |
+
train_result = trainer.train(resume_from_checkpoint=checkpoint)
|
596 |
+
trainer.save_model() # Saves the tokenizer too for easy upload
|
597 |
+
|
598 |
+
metrics = train_result.metrics
|
599 |
+
|
600 |
+
max_train_samples = (
|
601 |
+
data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
|
602 |
+
)
|
603 |
+
metrics["train_samples"] = min(max_train_samples, len(train_dataset))
|
604 |
+
|
605 |
+
trainer.log_metrics("train", metrics)
|
606 |
+
trainer.save_metrics("train", metrics)
|
607 |
+
trainer.save_state()
|
608 |
+
|
609 |
+
# Evaluation
|
610 |
+
if training_args.do_eval:
|
611 |
+
logger.info("*** Evaluate ***")
|
612 |
+
|
613 |
+
metrics = trainer.evaluate()
|
614 |
+
|
615 |
+
max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
|
616 |
+
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
617 |
+
try:
|
618 |
+
perplexity = math.exp(metrics["eval_loss"])
|
619 |
+
except OverflowError:
|
620 |
+
perplexity = float("inf")
|
621 |
+
metrics["perplexity"] = perplexity
|
622 |
+
|
623 |
+
trainer.log_metrics("eval", metrics)
|
624 |
+
trainer.save_metrics("eval", metrics)
|
625 |
+
|
626 |
+
kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "text-generation"}
|
627 |
+
if data_args.dataset_name is not None:
|
628 |
+
kwargs["dataset_tags"] = data_args.dataset_name
|
629 |
+
if data_args.dataset_config_name is not None:
|
630 |
+
kwargs["dataset_args"] = data_args.dataset_config_name
|
631 |
+
kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
|
632 |
+
else:
|
633 |
+
kwargs["dataset"] = data_args.dataset_name
|
634 |
+
|
635 |
+
if training_args.push_to_hub:
|
636 |
+
trainer.push_to_hub(**kwargs)
|
637 |
+
else:
|
638 |
+
trainer.create_model_card(**kwargs)
|
639 |
+
|
640 |
+
if __name__ == "__main__":
|
641 |
+
main()
|
benchmarks/babylm/scripts/eval.py
ADDED
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# coding=utf-8
|
3 |
+
# Copyright 2020 The HuggingFace Inc. team. All rights reserved.
|
4 |
+
#
|
5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
6 |
+
# you may not use this file except in compliance with the License.
|
7 |
+
# You may obtain a copy of the License at
|
8 |
+
#
|
9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
10 |
+
#
|
11 |
+
# Unless required by applicable law or agreed to in writing, software
|
12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14 |
+
# See the License for the specific language governing permissions and
|
15 |
+
# limitations under the License.
|
16 |
+
"""
|
17 |
+
Fine-tuning the library models for causal language modeling (GPT, GPT-2, CTRL, ...) on a text file or a dataset.
|
18 |
+
|
19 |
+
Here is the full list of checkpoints on the hub that can be fine-tuned by this script:
|
20 |
+
https://huggingface.co/models?filter=text-generation
|
21 |
+
"""
|
22 |
+
# You can also adapt this script on your own causal language modeling task. Pointers for this are left as comments.
|
23 |
+
|
24 |
+
import logging
|
25 |
+
import math
|
26 |
+
import os
|
27 |
+
# disable logging until training starts
|
28 |
+
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
|
29 |
+
import sys
|
30 |
+
from dataclasses import dataclass, field
|
31 |
+
from itertools import chain
|
32 |
+
from typing import Optional
|
33 |
+
|
34 |
+
import datasets
|
35 |
+
import evaluate
|
36 |
+
import torch
|
37 |
+
from datasets import load_dataset
|
38 |
+
|
39 |
+
import transformers
|
40 |
+
from transformers import (
|
41 |
+
CONFIG_MAPPING,
|
42 |
+
MODEL_FOR_CAUSAL_LM_MAPPING,
|
43 |
+
AutoConfig,
|
44 |
+
AutoModelForCausalLM,
|
45 |
+
AutoTokenizer,
|
46 |
+
HfArgumentParser,
|
47 |
+
Trainer,
|
48 |
+
TrainingArguments,
|
49 |
+
default_data_collator,
|
50 |
+
is_torch_tpu_available,
|
51 |
+
set_seed,
|
52 |
+
)
|
53 |
+
from transformers.testing_utils import CaptureLogger
|
54 |
+
from transformers.trainer_utils import get_last_checkpoint
|
55 |
+
from transformers.utils import check_min_version, send_example_telemetry
|
56 |
+
from transformers.utils.versions import require_version
|
57 |
+
|
58 |
+
from transformers import AutoModel, AutoTokenizer
|
59 |
+
from datasets import load_dataset
|
60 |
+
from transformers.testing_utils import CaptureLogger
|
61 |
+
|
62 |
+
from itertools import chain
|
63 |
+
|
64 |
+
logger = logging.getLogger(__name__)
|
65 |
+
|
66 |
+
|
67 |
+
def get_score(submission_folder = "../env"):
|
68 |
+
training_args = TrainingArguments("test_trainer")
|
69 |
+
training_args.report_to = []
|
70 |
+
raw_datasets = load_dataset(submission_folder + "/babyLM_for_hf.py", "babyLM-10M", split="test")
|
71 |
+
model = AutoModelForCausalLM.from_pretrained(submission_folder + "/output/")
|
72 |
+
tokenizer = AutoTokenizer.from_pretrained(submission_folder + "/output/")
|
73 |
+
|
74 |
+
# Preprocessing the datasets.
|
75 |
+
# First we tokenize all the texts.
|
76 |
+
column_names = list(raw_datasets.features)
|
77 |
+
text_column_name = "text" if "text" in column_names else column_names[0]
|
78 |
+
|
79 |
+
# since this will be pickled to avoid _LazyModule error in Hasher force logger loading before tokenize_function
|
80 |
+
tok_logger = transformers.utils.logging.get_logger("transformers.tokenization_utils_base")
|
81 |
+
|
82 |
+
def tokenize_function(examples):
|
83 |
+
with CaptureLogger(tok_logger) as cl:
|
84 |
+
output = tokenizer(examples[text_column_name])
|
85 |
+
# clm input could be much much longer than block_size
|
86 |
+
if "Token indices sequence length is longer than the" in cl.out:
|
87 |
+
tok_logger.warning(
|
88 |
+
"^^^^^^^^^^^^^^^^ Please ignore the warning above - this long input will be chunked into smaller bits"
|
89 |
+
" before being passed to the model."
|
90 |
+
)
|
91 |
+
return output
|
92 |
+
|
93 |
+
with training_args.main_process_first(desc="dataset map tokenization"):
|
94 |
+
# if not data_args.streaming:
|
95 |
+
# tokenized_datasets = raw_datasets.map(
|
96 |
+
# tokenize_function,
|
97 |
+
# batched=True,
|
98 |
+
# num_proc=data_args.preprocessing_num_workers,
|
99 |
+
# remove_columns=column_names,
|
100 |
+
# load_from_cache_file=not data_args.overwrite_cache,
|
101 |
+
# desc="Running tokenizer on dataset",
|
102 |
+
# )
|
103 |
+
# else:
|
104 |
+
tokenized_datasets = raw_datasets.map(
|
105 |
+
tokenize_function,
|
106 |
+
batched=True,
|
107 |
+
remove_columns=column_names,
|
108 |
+
)
|
109 |
+
|
110 |
+
if True:
|
111 |
+
block_size = tokenizer.model_max_length
|
112 |
+
if block_size > 1024:
|
113 |
+
logger.warning(
|
114 |
+
"The chosen tokenizer supports a `model_max_length` that is longer than the default `block_size` value"
|
115 |
+
" of 1024. If you would like to use a longer `block_size` up to `tokenizer.model_max_length` you can"
|
116 |
+
" override this default with `--block_size xxx`."
|
117 |
+
)
|
118 |
+
block_size = 1024
|
119 |
+
else:
|
120 |
+
if data_args.block_size > tokenizer.model_max_length:
|
121 |
+
logger.warning(
|
122 |
+
f"The block_size passed ({data_args.block_size}) is larger than the maximum length for the model"
|
123 |
+
f"({tokenizer.model_max_length}). Using block_size={tokenizer.model_max_length}."
|
124 |
+
)
|
125 |
+
block_size = min(data_args.block_size, tokenizer.model_max_length)
|
126 |
+
|
127 |
+
# Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
|
128 |
+
def group_texts(examples):
|
129 |
+
# Concatenate all texts.
|
130 |
+
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
|
131 |
+
total_length = len(concatenated_examples[list(examples.keys())[0]])
|
132 |
+
# We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict.
|
133 |
+
# We could add padding if the model supported it instead of this drop, you can customize this part to your needs.
|
134 |
+
total_length = (total_length // block_size) * block_size
|
135 |
+
# Split by chunks of max_len.
|
136 |
+
result = {
|
137 |
+
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
|
138 |
+
for k, t in concatenated_examples.items()
|
139 |
+
}
|
140 |
+
result["labels"] = result["input_ids"].copy()
|
141 |
+
return result
|
142 |
+
|
143 |
+
# Note that with `batched=True`, this map processes 1,000 texts together, so group_texts throws away a remainder
|
144 |
+
# for each of those groups of 1,000 texts. You can adjust that batch_size here but a higher value might be slower
|
145 |
+
# to preprocess.
|
146 |
+
#
|
147 |
+
# To speed up this part, we use multiprocessing. See the documentation of the map method for more information:
|
148 |
+
# https://huggingface.co/docs/datasets/package_reference/main_classes.html#datasets.Dataset.map
|
149 |
+
|
150 |
+
with training_args.main_process_first(desc="grouping texts together"):
|
151 |
+
# if not data_args.streaming:
|
152 |
+
# lm_datasets = tokenized_datasets.map(
|
153 |
+
# group_texts,
|
154 |
+
# batched=True,
|
155 |
+
# num_proc=data_args.preprocessing_num_workers,
|
156 |
+
# load_from_cache_file=not data_args.overwrite_cache,
|
157 |
+
# desc=f"Grouping texts in chunks of {block_size}",
|
158 |
+
# )
|
159 |
+
# else:
|
160 |
+
lm_datasets = tokenized_datasets.map(
|
161 |
+
group_texts,
|
162 |
+
batched=True,
|
163 |
+
)
|
164 |
+
eval_dataset = lm_datasets
|
165 |
+
|
166 |
+
def preprocess_logits_for_metrics(logits, labels):
|
167 |
+
if isinstance(logits, tuple):
|
168 |
+
# Depending on the model and config, logits may contain extra tensors,
|
169 |
+
# like past_key_values, but logits always come first
|
170 |
+
logits = logits[0]
|
171 |
+
return logits.argmax(dim=-1)
|
172 |
+
|
173 |
+
metric = evaluate.load("accuracy")
|
174 |
+
|
175 |
+
def compute_metrics(eval_preds):
|
176 |
+
preds, labels = eval_preds
|
177 |
+
# preds have the same shape as the labels, after the argmax(-1) has been calculated
|
178 |
+
# by preprocess_logits_for_metrics but we need to shift the labels
|
179 |
+
labels = labels[:, 1:].reshape(-1)
|
180 |
+
preds = preds[:, :-1].reshape(-1)
|
181 |
+
return metric.compute(predictions=preds, references=labels)
|
182 |
+
|
183 |
+
# Initialize our Trainer
|
184 |
+
trainer = Trainer(
|
185 |
+
model=model,
|
186 |
+
args=training_args,
|
187 |
+
train_dataset=None,
|
188 |
+
eval_dataset=eval_dataset,
|
189 |
+
tokenizer=tokenizer,
|
190 |
+
# Data collator will default to DataCollatorWithPadding, so we change it.
|
191 |
+
data_collator=default_data_collator,
|
192 |
+
compute_metrics=compute_metrics,
|
193 |
+
preprocess_logits_for_metrics=preprocess_logits_for_metrics,
|
194 |
+
)
|
195 |
+
|
196 |
+
transformers.utils.logging.set_verbosity(transformers.utils.logging.WARNING)
|
197 |
+
|
198 |
+
# Evaluation
|
199 |
+
metrics = trainer.evaluate()
|
200 |
+
|
201 |
+
max_eval_samples = len(eval_dataset)
|
202 |
+
metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
|
203 |
+
try:
|
204 |
+
perplexity = math.exp(metrics["eval_loss"])
|
205 |
+
except OverflowError:
|
206 |
+
perplexity = float("inf")
|
207 |
+
metrics["perplexity"] = perplexity
|
208 |
+
|
209 |
+
return perplexity
|
210 |
+
|
211 |
+
if __name__ == "__main__":
|
212 |
+
print(get_score())
|
benchmarks/babylm/scripts/prepare.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import subprocess
|
2 |
+
import pandas as pd
|
3 |
+
|
4 |
+
taskname = "babylm"
|
5 |
+
download_dir = "../env"
|
6 |
+
|
7 |
+
subprocess.run(["wget", "https://github.com/babylm/babylm.github.io/raw/main/babylm_data.zip"], cwd=download_dir)
|
8 |
+
subprocess.run(["unzip", "-n", f"babylm_data.zip"], cwd=download_dir)
|
9 |
+
subprocess.run(["rm", f"babylm_data.zip"], cwd=download_dir)
|
10 |
+
subprocess.run(["rm", "-rf", f"babylm_data/babylm_100M"], cwd=download_dir)
|
11 |
+
subprocess.run(["rm", "-rf", f"__MACOSX"], cwd=download_dir)
|
benchmarks/babylm/scripts/read_only_files.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
babylm_data/*
|
2 |
+
__MACOSX/*
|
benchmarks/babylm/scripts/research_problem.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Improve the baseline model performance on the babyLM Benchmark.
|
2 |
+
|
3 |
+
Summary: This shared task challenges community members to train a language model **from scratch** on the same amount of linguistic data available to a child. Submissions should be implemented in Huggingface's Transformers library and will be evaluated on a shared pipeline. This shared task is co-sponsored by CMCL and CoNLL.
|
4 |
+
|
5 |
+
To run the baseline model, execute train.py. It will train a standard gpt2 model on the babyLM data. The final model will be saved to output/ folder.
|
6 |
+
|
7 |
+
When you submit your final answer, you will be evaluated on the performance of the checkpoint saved in the output folder. It will be evaluated on a held-out test set.
|
benchmarks/bibtex-generation/env/arxiv_API_reference.txt
ADDED
@@ -0,0 +1,599 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
arXiv API User's Manual
|
2 |
+
Please review the Terms of Use for arXiv APIs before using the arXiv API.
|
3 |
+
|
4 |
+
Table of Contents
|
5 |
+
1. Preface
|
6 |
+
2. API QuickStart
|
7 |
+
3. Structure of the API
|
8 |
+
3.1. Calling the API
|
9 |
+
3.1.1. Query Interface
|
10 |
+
3.1.1.1. search_query and id_list logic
|
11 |
+
3.1.1.2. start and max_results paging
|
12 |
+
3.1.1.3. sort order for return results
|
13 |
+
3.2. The API Response
|
14 |
+
3.3. Outline of an Atom feed
|
15 |
+
3.3.1. Feed Metadata
|
16 |
+
3.3.1.1. <title>, <id>, <link> and <updated>
|
17 |
+
3.3.1.2. OpenSearch Extension Elements
|
18 |
+
3.3.2. Entry Metadata
|
19 |
+
3.3.2.1. <title>, <id>, <published>, and <updated>
|
20 |
+
3.3.2.1. <summary>, <author> and <category>
|
21 |
+
3.3.2.3. <link>'s
|
22 |
+
3.3.2.4. <arxiv> extension elements
|
23 |
+
3.4. Errors
|
24 |
+
4. Examples
|
25 |
+
4.1. Simple Examples
|
26 |
+
4.1.1. Perl
|
27 |
+
4.1.2. Python
|
28 |
+
4.1.3. Ruby
|
29 |
+
4.1.4. PHP
|
30 |
+
4.2. Detailed Parsing Examples
|
31 |
+
5. Appendices
|
32 |
+
5.1. Details of Query Construction
|
33 |
+
5.1.1. A Note on Article Versions
|
34 |
+
5.2. Details of Atom Results Returned
|
35 |
+
5.3. Subject Classifications
|
36 |
+
|
37 |
+
|
38 |
+
1. Preface
|
39 |
+
The arXiv API allows programmatic access to the hundreds of thousands of e-prints hosted on arXiv.org.
|
40 |
+
|
41 |
+
This manual is meant to provide an introduction to using the API, as well as documentation describing its details, and as such is meant to be read by both beginning and advanced users. To get a flavor for how the API works, see the API Quickstart. For more detailed information, see Structure of the API.
|
42 |
+
|
43 |
+
For examples of using the API from several popular programming languages including perl, python and ruby, see the Examples section.
|
44 |
+
|
45 |
+
Finally, the Appendices contain an explanation of all input parameters to the API, as well as the output format.
|
46 |
+
|
47 |
+
|
48 |
+
2. API QuickStart
|
49 |
+
The easiest place to start with the API is by accessing it through a web browser. For examples of accessing the API through common programming languages, see the Examples section.
|
50 |
+
|
51 |
+
Most everyone that has read or submitted e-prints on the arXiv is familiar with the arXiv human web interface. These HTML pages can be accessed by opening up your web browser, and entering the following url in your web browser
|
52 |
+
|
53 |
+
http://arxiv.org
|
54 |
+
|
55 |
+
From there, the article listings can be browsed by clicking on one of the many links, or you can search for articles using the search box in the upper right hand side of the page. For example, if I wanted to search for articles that contain the word electron in the title or abstract, I would type electron in the search box, and click Go. If you follow my example, you will see something like this: a web page listing the title and authors of each result, with links to the abstract page, pdf, etc.
|
56 |
+
|
57 |
+
In its simplest form, the API can be used in exactly the same way. However, it uses a few shortcuts so there is less clicking involved. For example, you can see the same search results for electron by entering the url
|
58 |
+
|
59 |
+
http://export.arxiv.org/api/query?search_query=all:electron.
|
60 |
+
|
61 |
+
Alternatively, you can search for articles that contain electron AND proton with the API by entering
|
62 |
+
|
63 |
+
http://export.arxiv.org/api/query?search_query=all:electron+AND+all:proton
|
64 |
+
|
65 |
+
What you see will look different from the HTML interface, but it contains the same information as the search done with the human interface. The reason why the results look different is that the API returns results in the Atom 1.0 format, and not HTML. Since Atom is defined as an XML grammar, it is much easier to digest for programs than HTML. The API is not intended to be used inside a web browser by itself, but this is a particularly simple way to debug a program that does use the API.
|
66 |
+
|
67 |
+
You might notice that your web browser has asked you if you want to “subscribe to this feed” after you enter the API url. This is because Atom is one of the formats used by web sites to syndicate their content. These feeds are usually read with feed reader software, and are what is generated by the existing arXiv rss feeds. The current arXiv feeds only give you updates on new papers within the category you specify. One immediately useful thing to do with the API then is to generate your own feed, based on a custom query!
|
68 |
+
|
69 |
+
To learn more about how to construct custom search queries with the API, see the appendix on the details of query construction. To learn about what information is returned by the API, see the section on the API response. To learn more about writing programs to call the API, and digest the responses, we suggest starting with the section on Structure of the API.
|
70 |
+
|
71 |
+
|
72 |
+
3. Structure of the API
|
73 |
+
In this section, we'll go over some of the details of interacting with the API. A diagram of a typical API call is shown below:
|
74 |
+
|
75 |
+
Example: A typical API call
|
76 |
+
|
77 |
+
|
78 |
+
Request from url: http://export.arxiv.org/api/query (1)
|
79 |
+
with parameters: search_query=all:electron
|
80 |
+
.
|
81 |
+
.
|
82 |
+
.
|
83 |
+
API server processes the request and sends the response
|
84 |
+
.
|
85 |
+
.
|
86 |
+
.
|
87 |
+
Response received by client. (2)
|
88 |
+
The request can be made via HTTP GET, in which the parameters are encoded in the url, or via an HTTP POST in which the parameters are encoded in the HTTP request header. Most client libraries support both methods.
|
89 |
+
|
90 |
+
If all goes well, the HTTP header will show a 200 OK status, and the response body will contain the Atom response content as shown in the example response.
|
91 |
+
|
92 |
+
|
93 |
+
3.1. Calling the API
|
94 |
+
As mentioned above, the API can be called with an HTTP request of type GET or POST. For our purposes, the main difference is that the parameters are included in the url for a GET request, but not for the POST request. Thus if the parameters list is unusually long, a POST request might be preferred.
|
95 |
+
|
96 |
+
The parameters for each of the API methods are explained below. For each method, the base url is
|
97 |
+
|
98 |
+
|
99 |
+
http://export.arxiv.org/api/{method_name}?{parameters}
|
100 |
+
|
101 |
+
3.1.1. Query Interface
|
102 |
+
The API query interface has method_name=query. The table below outlines the parameters that can be passed to the query interface. Parameters are separated with the & sign in the constructed url's.
|
103 |
+
|
104 |
+
query
|
105 |
+
parameters type defaults required
|
106 |
+
search_query string None No
|
107 |
+
id_list comma-delimited string None No
|
108 |
+
start int 0 No
|
109 |
+
max_results int 10 No
|
110 |
+
|
111 |
+
3.1.1.1. SEARCH_QUERY AND ID_LIST LOGIC
|
112 |
+
We have already seen the use of search_query in the quickstart section. The search_query takes a string that represents a search query used to find articles. The construction of search_query is described in the search query construction appendix. The id_list contains a comma-delimited list of arXiv id's.
|
113 |
+
|
114 |
+
The logic of these two parameters is as follows:
|
115 |
+
|
116 |
+
If only search_query is given (id_list is blank or not given), then the API will return results for each article that matches the search query.
|
117 |
+
|
118 |
+
If only id_list is given (search_query is blank or not given), then the API will return results for each article in id_list.
|
119 |
+
|
120 |
+
If BOTH search_query and id_list are given, then the API will return each article in id_list that matches search_query. This allows the API to act as a results filter.
|
121 |
+
|
122 |
+
This is summarized in the following table:
|
123 |
+
|
124 |
+
search_query present id_list present API returns
|
125 |
+
yes no articles that match search_query
|
126 |
+
no yes articles that are in id_list
|
127 |
+
yes yes articles in id_list that also match search_query
|
128 |
+
|
129 |
+
3.1.1.2. START AND MAX_RESULTS PAGING
|
130 |
+
Many times there are hundreds of results for an API query. Rather than download information about all the results at once, the API offers a paging mechanism through start and max_results that allows you to download chucks of the result set at a time. Within the total results set, start defines the index of the first returned result, using 0-based indexing. max_results is the number of results returned by the query. For example, if wanted to step through the results of a search_query of all:electron, we would construct the urls:
|
131 |
+
|
132 |
+
|
133 |
+
http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=10 (1)
|
134 |
+
http://export.arxiv.org/api/query?search_query=all:electron&start=10&max_results=10 (2)
|
135 |
+
http://export.arxiv.org/api/query?search_query=all:electron&start=20&max_results=10 (3)
|
136 |
+
Get results 0-9
|
137 |
+
|
138 |
+
Get results 10-19
|
139 |
+
|
140 |
+
Get results 20-29
|
141 |
+
|
142 |
+
Detailed examples of how to perform paging in a variety of programming languages can be found in the examples section.
|
143 |
+
|
144 |
+
In cases where the API needs to be called multiple times in a row, we encourage you to play nice and incorporate a 3 second delay in your code. The detailed examples below illustrate how to do this in a variety of languages.
|
145 |
+
|
146 |
+
Because of speed limitations in our implementation of the API, the maximum number of results returned from a single call (max_results) is limited to 30000 in slices of at most 2000 at a time, using the max_results and start query parameters. For example to retrieve matches 6001-8000: http://export.arxiv.org/api/query?search_query=all:electron&start=6000&max_results=8000
|
147 |
+
|
148 |
+
Large result sets put considerable load on the server and also take a long time to render. We recommend to refine queries which return more than 1,000 results, or at least request smaller slices. For bulk metadata harvesting or set information, etc., the OAI-PMH interface is more suitable. A request with max_results >30,000 will result in an HTTP 400 error code with appropriate explanation. A request for 30000 results will typically take a little over 2 minutes to return a response of over 15MB. Requests for fewer results are much faster and correspondingly smaller.
|
149 |
+
|
150 |
+
|
151 |
+
3.1.1.3. SORT ORDER FOR RETURN RESULTS
|
152 |
+
There are two options for for the result set to the API search, sortBy and sortOrder.
|
153 |
+
|
154 |
+
sortBy can be "relevance", "lastUpdatedDate", "submittedDate"
|
155 |
+
|
156 |
+
sortOrder can be either "ascending" or "descending"
|
157 |
+
|
158 |
+
A sample query using these new parameters looks like:
|
159 |
+
|
160 |
+
|
161 |
+
http://export.arxiv.org/api/query?search_query=ti:"electron thermal conductivity"&sortBy=lastUpdatedDate&sortOrder=ascending
|
162 |
+
|
163 |
+
3.2. The API Response
|
164 |
+
Everything returned by the API in the body of the HTTP responses is Atom 1.0, including errors. Atom is a grammar of XML that is popular in the world of content syndication, and is very similar to RSS for this purpose. Typically web sites with dynamic content such as news sites and blogs will publish their content as Atom or RSS feeds. However, Atom is a general format that embodies the concept of a list of items, and thus is well-suited to returning the arXiv search results.
|
165 |
+
|
166 |
+
|
167 |
+
3.3. Outline of an Atom feed
|
168 |
+
In this section we will discuss the contents of the Atom documents returned by the API. To see the full explanation of the Atom 1.0 format, please see the Atom specification.
|
169 |
+
|
170 |
+
An API response consists of an Atom <feed> element which contains metadata about the API call performed, as well as child <entry> elements which embody the metadata for each of the returned results. Below we explain each of the elements and attributes. We will base our discussion on the sample results feed discussed in the examples section.
|
171 |
+
|
172 |
+
You may notice that the results from the API are ordered differently that the results given by the HTML arXiv search interface. The HTML interface automatically sorts results in descending order based on the date of their submission, while the API returns results according to relevancy from the internal search engine. Thus when debugging a search query, we encourage you to use the API within a web browser, rather than the HTML search interface. If you want sorting by date, you can always do this within your programs by reading the <published> tag for each entry as explained below.
|
173 |
+
|
174 |
+
|
175 |
+
3.3.1. Feed Metadata
|
176 |
+
Every response will contain the line:
|
177 |
+
|
178 |
+
|
179 |
+
<?xml version="1.0" encoding="utf-8"?>
|
180 |
+
to signify that we are receiving XML 1.0 with a UTF-8 encoding. Following that line will be a line indicating that we are receiving an Atom feed:
|
181 |
+
|
182 |
+
|
183 |
+
<feed xmlns="http://www.w3.org/2005/Atom"
|
184 |
+
xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/"
|
185 |
+
xmlns:arxiv="http://arxiv.org/schemas/atom">
|
186 |
+
You will notice that three XML namespaces are defined. The default namespace signifies that we are dealing with Atom 1.0. The other two namespaces define extensions to Atom that we describe below.
|
187 |
+
|
188 |
+
|
189 |
+
3.3.1.1. <TITLE>, <ID>, <LINK> AND <UPDATED>
|
190 |
+
The <title> element gives the title for the feed:
|
191 |
+
|
192 |
+
|
193 |
+
<title xmlns="http://www.w3.org/2005/Atom">
|
194 |
+
ArXiv Query: search_query=all:electron&id_list=&start=0&max_results=1
|
195 |
+
</title>
|
196 |
+
The title contains a canonicalized version of the query used to call the API. The canonicalization includes all parameters, using their defaults if they were not included, and always puts them in the order search_query,id_list,start,max_results, even if they were specified in a different order in the actual query.
|
197 |
+
|
198 |
+
The <id> element serves as a unique id for this query, and is useful if you are writing a program such as a feed reader that wants to keep track of all the feeds requested in the past. This id can then be used as a key in a database.
|
199 |
+
|
200 |
+
|
201 |
+
<id xmlns="http://www.w3.org/2005/Atom">
|
202 |
+
http://arxiv.org/api/cHxbiOdZaP56ODnBPIenZhzg5f8
|
203 |
+
</id>
|
204 |
+
The id is guaranteed to be unique for each query.
|
205 |
+
|
206 |
+
The <link> element provides a URL that can be used to retrieve this feed again.
|
207 |
+
|
208 |
+
|
209 |
+
<link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/api/query?search_query=all:electron&id_list=&start=0&max_results=1" rel="self" type="application/atom+xml"/>
|
210 |
+
Note that the url in the link represents the canonicalized version of the query. The <link> provides a GET requestable url, even if the original request was done via POST.
|
211 |
+
|
212 |
+
The <updated> element provides the last time the contents of the feed were last updated:
|
213 |
+
|
214 |
+
|
215 |
+
<updated xmlns="http://www.w3.org/2005/Atom">2007-10-08T00:00:00-04:00</updated>
|
216 |
+
Because the arXiv submission process works on a 24 hour submission cycle, new articles are only available to the API on the midnight after the articles were processed. The <updated> tag thus reflects the midnight of the day that you are calling the API. This is very important - search results do not change until new articles are added. Therefore there is no need to call the API more than once in a day for the same query. Please cache your results. This primarily applies to production systems, and of course you are free to play around with the API while you are developing your program!
|
217 |
+
|
218 |
+
|
219 |
+
3.3.1.2. OPENSEARCH EXTENSION ELEMENTS
|
220 |
+
There are several extension elements defined in the OpenSearch namespace
|
221 |
+
|
222 |
+
|
223 |
+
http://a9.com/-/spec/opensearch/1.1/
|
224 |
+
OpenSearch is a lightweight technology that acts in a similar way as the Web Services Description Language. The OpenSearch elements we have included allow OpenSearch enabled clients to digest our results. Such clients often include search result aggregators and browser pluggins that allow searching from a variety of sources.
|
225 |
+
|
226 |
+
The OpenSearch extension elements can still be useful to you even if you are not writing one of these applications. The <opensearch:totalResults> element lists how many results are in the result set for the query:
|
227 |
+
|
228 |
+
|
229 |
+
<opensearch:totalResults xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">
|
230 |
+
1000
|
231 |
+
</opensearch:totalResults>
|
232 |
+
This can be very useful when implementing paging of search results. The other two elements <opensearch:startIndex>, and <opensearch:itemsPerPage> are analogous to start, and max_results discussed above.
|
233 |
+
|
234 |
+
|
235 |
+
<opensearch:startIndex xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">
|
236 |
+
0
|
237 |
+
</opensearch:startIndex>
|
238 |
+
<opensearch:itemsPerPage xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">
|
239 |
+
1
|
240 |
+
</opensearch:itemsPerPage>
|
241 |
+
|
242 |
+
3.3.2. Entry Metadata
|
243 |
+
If there are no errors, the <feed> element contains 0 or more child <entry> elements with each <entry> representing an article in the returned results set. As explained in the errors section, if there are errors, a single <entry> element representing the error is returned. Below the element description describes the elements for <entry>'s representing arXiv articles. For a general discussion of arXiv metadata, see the arXiv metadata explanation.
|
244 |
+
|
245 |
+
|
246 |
+
3.3.2.1. <TITLE>, <ID>, <PUBLISHED>, AND <UPDATED>
|
247 |
+
The <title> element contains the title of the article returned:
|
248 |
+
|
249 |
+
|
250 |
+
<title xmlns="http://www.w3.org/2005/Atom">
|
251 |
+
Multi-Electron Production at High Transverse Momenta in ep Collisions at HERA
|
252 |
+
</title>
|
253 |
+
The <id> element contains a url that resolves to the abstract page for that article:
|
254 |
+
|
255 |
+
|
256 |
+
<id xmlns="http://www.w3.org/2005/Atom">
|
257 |
+
http://arxiv.org/abs/hep-ex/0307015
|
258 |
+
</id>
|
259 |
+
If you want only the arXiv id for the article, you can remove the leading http://arxiv.org/abs/ in the <id>.
|
260 |
+
|
261 |
+
The <published> tag contains the date in which the first version of this article was submitted and processed. The <updated> element contains the date on which the retrieved article was submitted and processed. If the version is version 1, then <published> == <updated>, otherwise they are different. In the example below, the article retrieved was version 2, so <updated> and <published> are different (see the original query).
|
262 |
+
|
263 |
+
|
264 |
+
<published xmlns="http://www.w3.org/2005/Atom">
|
265 |
+
2007-02-27T16:02:02-05:00
|
266 |
+
</published>
|
267 |
+
<updated xmlns="http://www.w3.org/2005/Atom">
|
268 |
+
2007-06-25T17:09:59-04:00
|
269 |
+
</updated>
|
270 |
+
|
271 |
+
3.3.2.2. <SUMMARY>, <AUTHOR> AND <CATEGORY>
|
272 |
+
The <summary> element contains the abstract for the article:
|
273 |
+
|
274 |
+
|
275 |
+
<summary xmlns="http://www.w3.org/2005/Atom">
|
276 |
+
Multi-electron production is studied at high electron transverse momentum
|
277 |
+
in positron- and electron-proton collisions using the H1 detector at HERA.
|
278 |
+
The data correspond to an integrated luminosity of 115 pb-1. Di-electron
|
279 |
+
and tri-electron event yields are measured. Cross sections are derived in
|
280 |
+
a restricted phase space region dominated by photon-photon collisions. In
|
281 |
+
general good agreement is found with the Standard Model predictions.
|
282 |
+
However, for electron pair invariant masses above 100 GeV, three
|
283 |
+
di-electron events and three tri-electron events are observed, compared to
|
284 |
+
Standard Model expectations of 0.30 \pm 0.04 and 0.23 \pm 0.04,
|
285 |
+
respectively.
|
286 |
+
</summary>
|
287 |
+
There is one <author> element for each author of the paper in order of authorship. Each <author> element has a <name> sub-element which contains the name of the author.
|
288 |
+
|
289 |
+
|
290 |
+
<author xmlns="http://www.w3.org/2005/Atom">
|
291 |
+
<name xmlns="http://www.w3.org/2005/Atom">H1 Collaboration</name>
|
292 |
+
</author>
|
293 |
+
If author affiliation is present, it is included as an <arxiv:affiliation> subelement of the <author> element as discussed below.
|
294 |
+
|
295 |
+
The <category> element is used to describe either an arXiv, ACM, or MSC classification. See the arXiv metadata explanation for more details about these classifications. The <category> element has two attributes, scheme, which is the categorization scheme, and term which is the term used in the categorization. Here is an example from the query http://export.arxiv.org/api/query?id_list=cs/9901002v1
|
296 |
+
|
297 |
+
|
298 |
+
<category xmlns="http://www.w3.org/2005/Atom" term="cs.LG" scheme="http://arxiv.org/schemas/atom"/>
|
299 |
+
<category xmlns="http://www.w3.org/2005/Atom" term="cs.AI" scheme="http://arxiv.org/schemas/atom"/>
|
300 |
+
<category xmlns="http://www.w3.org/2005/Atom" term="I.2.6" scheme="http://arxiv.org/schemas/atom"/>
|
301 |
+
Note that in this example, there are 3 category elements, one for each category. The first two correspond to arXiv categories, and the last one to an ACM category. See <arxiv> extension elements below for information on how to identify the arXiv primary category.
|
302 |
+
|
303 |
+
|
304 |
+
3.3.2.3. <LINK>'S
|
305 |
+
For each entry, there are up to three <link> elements, distinguished by their rel and title attributes. The table below summarizes what these links refer to
|
306 |
+
|
307 |
+
rel title refers to always present
|
308 |
+
alternate - abstract page yes
|
309 |
+
related pdf pdf yes
|
310 |
+
related doi resolved doi no
|
311 |
+
For example:
|
312 |
+
|
313 |
+
|
314 |
+
<link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/abs/hep-ex/0307015v1" rel="alternate" type="text/html"/>
|
315 |
+
<link xmlns="http://www.w3.org/2005/Atom" title="pdf" href="http://arxiv.org/pdf/hep-ex/0307015v1" rel="related" type="application/pdf"/>
|
316 |
+
<link xmlns="http://www.w3.org/2005/Atom" title="doi" href="http://dx.doi.org/10.1529/biophysj.104.047340" rel="related"/>
|
317 |
+
|
318 |
+
3.3.2.4. <ARXIV> EXTENSION ELEMENTS
|
319 |
+
There are several pieces of arXiv metadata that are not able to be mapped onto the standard Atom specification. We have therefore defined several extension elements which live in the arxiv namespace
|
320 |
+
|
321 |
+
|
322 |
+
http://arxiv.org/schemas/atom
|
323 |
+
The arXiv classification system supports multiple <category> tags, as well as a primary classification. The primary classification is a replica of an Atom <category> tag, except it has the name <arxiv:primary_category>. For example, from the query http://export.arxiv.org/api/query?id_list=cs/9901002v1, we have
|
324 |
+
|
325 |
+
|
326 |
+
<arxiv:primary_category xmlns:arxiv="http://arxiv.org/schemas/atom" term="cs.LG" scheme="http://arxiv.org/schemas/atom"/>
|
327 |
+
signifying that cs.LG is the primary arXiv classification for this e-print.
|
328 |
+
|
329 |
+
The <arxiv:comment> element contains the typical author comments found on most arXiv articles:
|
330 |
+
|
331 |
+
|
332 |
+
<arxiv:comment xmlns:arxiv="http://arxiv.org/schemas/atom">
|
333 |
+
23 pages, 8 figures and 4 tables
|
334 |
+
</arxiv:comment>
|
335 |
+
If the author has supplied affiliation information, then this is included as an <arxiv:affiliation> subelement of the standard Atom <author> element. For example, from the query http://export.arxiv.org/api/query?id_list=0710.5765v1, we have
|
336 |
+
|
337 |
+
|
338 |
+
<author>
|
339 |
+
<name>G. G. Kacprzak</name>
|
340 |
+
<arxiv:affiliation xmlns:arxiv="http://arxiv.org/schemas/atom">NMSU</arxiv:affiliation>
|
341 |
+
</author>
|
342 |
+
If the author has provided a journal reference for the article, then there will be a <arxiv:journal_ref> element with this information:
|
343 |
+
|
344 |
+
|
345 |
+
<arxiv:journal_ref xmlns:arxiv="http://arxiv.org/schemas/atom">
|
346 |
+
Eur.Phys.J. C31 (2003) 17-29
|
347 |
+
</arxiv:journal_ref>
|
348 |
+
If the author has provided a DOI for the article, then there will be a <arxiv:doi> element with this information:
|
349 |
+
|
350 |
+
|
351 |
+
<arxiv:doi xmlns:arxiv="http://arxiv.org/schemas/atom">
|
352 |
+
10.1529/biophysj.104.047340
|
353 |
+
</arxiv:doi>
|
354 |
+
|
355 |
+
3.4. Errors
|
356 |
+
Errors are returned as Atom feeds with a single entry representing the error. The <summary> for the error contains a helpful error message, and the <link> element contains a url to a more detailed explanation of the message.
|
357 |
+
|
358 |
+
For example, the API call http://export.arxiv.org/api/query?id_list=1234.12345 contains a malformed id, and results in the error
|
359 |
+
|
360 |
+
|
361 |
+
<?xml version="1.0" encoding="utf-8"?>
|
362 |
+
<feed xmlns="http://www.w3.org/2005/Atom" xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">
|
363 |
+
<link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/api/query?search_query=&id_list=1234.12345" rel="self" type="application/atom+xml"/>
|
364 |
+
<title xmlns="http://www.w3.org/2005/Atom">ArXiv Query: search_query=&id_list=1234.12345</title>
|
365 |
+
<id xmlns="http://www.w3.org/2005/Atom">http://arxiv.org/api/kvuntZ8c9a4Eq5CF7KY03nMug+Q</id>
|
366 |
+
<updated xmlns="http://www.w3.org/2005/Atom">2007-10-12T00:00:00-04:00</updated>
|
367 |
+
<opensearch:totalResults xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">1</opensearch:totalResults>
|
368 |
+
<opensearch:startIndex xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">0</opensearch:startIndex>
|
369 |
+
|
370 |
+
<opensearch:itemsPerPage xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">1</opensearch:itemsPerPage>
|
371 |
+
<entry xmlns="http://www.w3.org/2005/Atom">
|
372 |
+
<id xmlns="http://www.w3.org/2005/Atom">http://arxiv.org/api/errors#incorrect_id_format_for_1234.12345</id>
|
373 |
+
<title xmlns="http://www.w3.org/2005/Atom">Error</title>
|
374 |
+
<summary xmlns="http://www.w3.org/2005/Atom">incorrect id format for 1234.12345</summary>
|
375 |
+
<updated xmlns="http://www.w3.org/2005/Atom">2007-10-12T00:00:00-04:00</updated>
|
376 |
+
|
377 |
+
<link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/api/errors#incorrect_id_format_for_1234.12345" rel="alternate" type="text/html"/>
|
378 |
+
<author xmlns="http://www.w3.org/2005/Atom">
|
379 |
+
<name xmlns="http://www.w3.org/2005/Atom">arXiv api core</name>
|
380 |
+
</author>
|
381 |
+
</entry>
|
382 |
+
</feed>
|
383 |
+
The following table gives information on errors that might occur.
|
384 |
+
|
385 |
+
Sample query Error Explanation
|
386 |
+
http://export.arxiv.org/api/query?start=not_an_int start must be an integer
|
387 |
+
http://export.arxiv.org/api/query?start=-1 start must be >= 0
|
388 |
+
http://export.arxiv.org/api/query?max_results=not_an_int max_results must be an integer
|
389 |
+
http://export.arxiv.org/api/query?max_results=-1 max_results must be >= 0
|
390 |
+
http://export.arxiv.org/api/query?id_list=1234.1234 malformed id - see arxiv identifier explanation
|
391 |
+
http://export.arxiv.org/api/query?id_list=cond—mat/0709123 malformed id - see arxiv identifier explanation
|
392 |
+
|
393 |
+
4. Examples
|
394 |
+
Once you have familiarized yourself with the API, you should be able to easily write programs that call the API automatically. Most programming languages, if not all, have libraries that allow you to make HTTP requests. Since Atom is growing, not all languages have libraries that support Atom parsing, so most of the programming effort will be in digesting the responses you receive. The languages that we know of that can easily handle calling the api via HTTP and parsing the results include:
|
395 |
+
|
396 |
+
Perl (via LWP) (example)
|
397 |
+
|
398 |
+
Python (via urllib) (example)
|
399 |
+
|
400 |
+
Ruby (via uri and net::http) (example)
|
401 |
+
|
402 |
+
PHP (via file_get_contents()) (example)
|
403 |
+
|
404 |
+
|
405 |
+
4.1. Simple Examples
|
406 |
+
Below we include code snippets for these languages that perform the bare minimum functionality - calling the api and printing the raw Atom results. If your favorite language is not up here, write us with an example, and we'll be glad to post it!
|
407 |
+
|
408 |
+
All of the simple examples produce an output which looks like:
|
409 |
+
|
410 |
+
Example: A Typical Atom Response
|
411 |
+
|
412 |
+
|
413 |
+
<?xml version="1.0" encoding="utf-8"?>
|
414 |
+
<feed xmlns="http://www.w3.org/2005/Atom" xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/" xmlns:arxiv="http://arxiv.org/schemas/atom">
|
415 |
+
<link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/api/query?search_query=all:electron&id_list=&start=0&max_results=1" rel="self" type="application/atom+xml"/>
|
416 |
+
<title xmlns="http://www.w3.org/2005/Atom">ArXiv Query: search_query=all:electron&id_list=&start=0&max_results=1</title>
|
417 |
+
<id xmlns="http://www.w3.org/2005/Atom">http://arxiv.org/api/cHxbiOdZaP56ODnBPIenZhzg5f8</id>
|
418 |
+
<updated xmlns="http://www.w3.org/2005/Atom">2007-10-08T00:00:00-04:00</updated>
|
419 |
+
<opensearch:totalResults xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">1000</opensearch:totalResults>
|
420 |
+
<opensearch:startIndex xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">0</opensearch:startIndex>
|
421 |
+
<opensearch:itemsPerPage xmlns:opensearch="http://a9.com/-/spec/opensearch/1.1/">1</opensearch:itemsPerPage>
|
422 |
+
<entry xmlns="http://www.w3.org/2005/Atom" xmlns:arxiv="http://arxiv.org/schemas/atom">
|
423 |
+
<id xmlns="http://www.w3.org/2005/Atom">http://arxiv.org/abs/hep-ex/0307015</id>
|
424 |
+
<published xmlns="http://www.w3.org/2005/Atom">2003-07-07T13:46:39-04:00</published>
|
425 |
+
<updated xmlns="http://www.w3.org/2005/Atom">2003-07-07T13:46:39-04:00</updated>
|
426 |
+
<title xmlns="http://www.w3.org/2005/Atom">Multi-Electron Production at High Transverse Momenta in ep Collisions at
|
427 |
+
HERA</title>
|
428 |
+
<summary xmlns="http://www.w3.org/2005/Atom"> Multi-electron production is studied at high electron transverse momentum in
|
429 |
+
positron- and electron-proton collisions using the H1 detector at HERA. The
|
430 |
+
data correspond to an integrated luminosity of 115 pb-1. Di-electron and
|
431 |
+
tri-electron event yields are measured. Cross sections are derived in a
|
432 |
+
restricted phase space region dominated by photon-photon collisions. In general
|
433 |
+
good agreement is found with the Standard Model predictions. However, for
|
434 |
+
electron pair invariant masses above 100 GeV, three di-electron events and
|
435 |
+
three tri-electron events are observed, compared to Standard Model expectations
|
436 |
+
of 0.30 \pm 0.04 and 0.23 \pm 0.04, respectively.
|
437 |
+
</summary>
|
438 |
+
<author xmlns="http://www.w3.org/2005/Atom">
|
439 |
+
<name xmlns="http://www.w3.org/2005/Atom">H1 Collaboration</name>
|
440 |
+
</author>
|
441 |
+
<arxiv:comment xmlns:arxiv="http://arxiv.org/schemas/atom">23 pages, 8 figures and 4 tables</arxiv:comment>
|
442 |
+
<arxiv:journal_ref xmlns:arxiv="http://arxiv.org/schemas/atom">Eur.Phys.J. C31 (2003) 17-29</arxiv:journal_ref>
|
443 |
+
<link xmlns="http://www.w3.org/2005/Atom" href="http://arxiv.org/abs/hep-ex/0307015v1" rel="alternate" type="text/html"/>
|
444 |
+
<link xmlns="http://www.w3.org/2005/Atom" title="pdf" href="http://arxiv.org/pdf/hep-ex/0307015v1" rel="related" type="application/pdf"/>
|
445 |
+
<arxiv:primary_category xmlns:arxiv="http://arxiv.org/schemas/atom" term="hep-ex" scheme="http://arxiv.org/schemas/atom"/>
|
446 |
+
<category term="hep-ex" scheme="http://arxiv.org/schemas/atom"/>
|
447 |
+
</entry>
|
448 |
+
</feed>
|
449 |
+
|
450 |
+
4.1.1. Perl
|
451 |
+
LWP is in the default perl installation on most platforms. It can be downloaded and installed from CPAN. Sample code to produce the above output is:
|
452 |
+
|
453 |
+
|
454 |
+
use LWP;
|
455 |
+
use strict;
|
456 |
+
|
457 |
+
my $url = 'http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1';
|
458 |
+
my $browser = LWP::UserAgent->new();
|
459 |
+
my $response = $browser->get($url);
|
460 |
+
print $response->content();
|
461 |
+
|
462 |
+
4.1.2. Python
|
463 |
+
The urllib module is part of the python standard library, and is included in any default installation of python. Sample code to produce the above output in Python 2.7 is:
|
464 |
+
|
465 |
+
|
466 |
+
import urllib
|
467 |
+
url = 'http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1'
|
468 |
+
data = urllib.urlopen(url).read()
|
469 |
+
print data
|
470 |
+
wheras in Python 3 an example would be:
|
471 |
+
|
472 |
+
|
473 |
+
import urllib.request as libreq
|
474 |
+
with libreq.urlopen('http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1') as url:
|
475 |
+
r = url.read()
|
476 |
+
print(r)
|
477 |
+
|
478 |
+
4.1.3. Ruby
|
479 |
+
The net/http and uri modules are part of the ruby standard library, and are included in any default installation of ruby. Sample code to produce the above output is:
|
480 |
+
|
481 |
+
|
482 |
+
require 'net/http'
|
483 |
+
require 'uri'
|
484 |
+
url = URI.parse('http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1')
|
485 |
+
res = Net::HTTP.get_response(url)
|
486 |
+
print res.body
|
487 |
+
|
488 |
+
4.1.4. PHP
|
489 |
+
The file_get_contents() function is part of the PHP core language:
|
490 |
+
|
491 |
+
|
492 |
+
<?php
|
493 |
+
$url = 'http://export.arxiv.org/api/query?search_query=all:electron&start=0&max_results=1';
|
494 |
+
$response = file_get_contents($url);
|
495 |
+
print_r($response);
|
496 |
+
?>
|
497 |
+
|
498 |
+
4.2. Detailed Parsing Examples
|
499 |
+
The examples above don't cover how to parse the Atom results returned to extract the information you might be interested in. They also don't cover how to do more advanced programming of the API to perform such tasks as downloading chunks of the full results list one page at a time. The table below contains links to more detailed examples for each of the languages above, as well as to the libraries used to parse Atom.
|
500 |
+
|
501 |
+
Language Library Parsing Example Paging Example
|
502 |
+
Perl XML::Atom parsing paging
|
503 |
+
Python feedparser parsing paging
|
504 |
+
Ruby feedtools parsing paging
|
505 |
+
PHP SimplePie parsing paging
|
506 |
+
|
507 |
+
5. Appendices
|
508 |
+
|
509 |
+
5.1. Details of Query Construction
|
510 |
+
As outlined in the Structure of the API section, the interface to the API is quite simple. This simplicity, combined with search_query construction, and result set filtering through id_list makes the API a powerful tool for harvesting data from the arXiv. In this section, we outline the possibilities for constructing search_query's to retrieve our desired article lists. We outlined how to use the id_list parameter to filter results sets in search_query and id_list logic.
|
511 |
+
|
512 |
+
In the arXiv search engine, each article is divided up into a number of fields that can individually be searched. For example, the titles of an article can be searched, as well as the author list, abstracts, comments and journal reference. To search one of these fields, we simply prepend the field prefix followed by a colon to our search term. For example, suppose we wanted to find all articles by the author Adrian Del Maestro. We could construct the following query
|
513 |
+
|
514 |
+
http://export.arxiv.org/api/query?search_query=au:del_maestro
|
515 |
+
|
516 |
+
This returns nine results. The following table lists the field prefixes for all the fields that can be searched.
|
517 |
+
|
518 |
+
prefix explanation
|
519 |
+
ti Title
|
520 |
+
au Author
|
521 |
+
abs Abstract
|
522 |
+
co Comment
|
523 |
+
jr Journal Reference
|
524 |
+
cat Subject Category
|
525 |
+
rn Report Number
|
526 |
+
id Id (use id_list instead)
|
527 |
+
all All of the above
|
528 |
+
Note: The id_list parameter should be used rather than search_query=id:xxx to properly handle article versions. In addition, note that all: searches in each of the fields simultaneously.
|
529 |
+
|
530 |
+
The API allows advanced query construction by combining these search fields with Boolean operators. For example, suppose we want to find all articles by the author Adrian DelMaestro that also contain the word checkerboard in the title. We could construct the following query, using the AND operator:
|
531 |
+
|
532 |
+
http://export.arxiv.org/api/query?search_query=au:del_maestro+AND+ti:checkerboard
|
533 |
+
|
534 |
+
As expected, this query picked out the one of the nine previous results with checkerboard in the title. Note that we included + signs in the urls to the API. In a url, a + sign encodes a space, which is useful since spaces are not allowed in url's. It is always a good idea to escape the characters in your url's, which is a common feature in most programming libraries that deal with url's. Note that the <title> of the returned feed has spaces in the query constructed. It is a good idea to look at <title> to see if you have escaped your url correctly.
|
535 |
+
|
536 |
+
The following table lists the three possible Boolean operators.
|
537 |
+
|
538 |
+
AND
|
539 |
+
OR
|
540 |
+
ANDNOT
|
541 |
+
The ANDNOT Boolean operator is particularly useful, as it allows us to filter search results based on certain fields. For example, if we wanted all of the articles by the author Adrian DelMaestro with titles that did not contain the word checkerboard, we could construct the following query:
|
542 |
+
|
543 |
+
http://export.arxiv.org/api/query?search_query=au:del_maestro+ANDNOT+ti:checkerboard
|
544 |
+
|
545 |
+
As expected, this query returns eight results.
|
546 |
+
|
547 |
+
Finally, even more complex queries can be used by using parentheses for grouping the Boolean expressions. To include parentheses in in a url, use %28 for a left-parens (, and %29 for a right-parens ). For example, if we wanted all of the articles by the author Adrian DelMaestro with titles that did not contain the words checkerboard, OR Pyrochore, we could construct the following query:
|
548 |
+
|
549 |
+
http://export.arxiv.org/api/query?search_query=au:del_maestro+ANDNOT+%28ti:checkerboard+OR+ti:Pyrochlore%29
|
550 |
+
|
551 |
+
This query returns three results. Notice that the <title> element displays the parenthesis correctly meaning that we used the correct url escaping.
|
552 |
+
|
553 |
+
So far we have only used single words as the field terms to search for. You can include entire phrases by enclosing the phrase in double quotes, escaped by %22. For example, if we wanted all of the articles by the author Adrian DelMaestro with titles that contain quantum criticality, we could construct the following query:
|
554 |
+
|
555 |
+
http://export.arxiv.org/api/query?search_query=au:del_maestro+AND+ti:%22quantum+criticality%22
|
556 |
+
|
557 |
+
This query returns one result, and notice that the feed <title> contains double quotes as expected. The table below lists the two grouping operators used in the API.
|
558 |
+
|
559 |
+
symbol encoding explanation
|
560 |
+
( ) %28 %29 Used to group Boolean expressions for Boolean operator precedence.
|
561 |
+
double quotes %22 %22 Used to group multiple words into phrases to search a particular field.
|
562 |
+
space + Used to extend a search_query to include multiple fields.
|
563 |
+
|
564 |
+
5.1.1. A Note on Article Versions
|
565 |
+
Each arXiv article has a version associated with it. The first time an article is posted, it is given a version number of 1. When subsequent corrections are made to an article, it is resubmitted, and the version number is incremented. At any time, any version of an article may be retrieved.
|
566 |
+
|
567 |
+
When using the API, if you want to retrieve the latest version of an article, you may simply enter the arxiv id in the id_list parameter. If you want to retrieve information about a specific version, you can do this by appending vn to the id, where n is the version number you are interested in.
|
568 |
+
|
569 |
+
For example, to retrieve the latest version of cond-mat/0207270, you could use the query http://export.arxiv.org/api/query?id_list=cond-mat/0207270. To retrieve the very first version of this article, you could use the query http://export.arxiv.org/api/query?id_list=cond-mat/0207270v1
|
570 |
+
|
571 |
+
|
572 |
+
5.2. Details of Atom Results Returned
|
573 |
+
The following table lists each element of the returned Atom results. For a more detailed explanation see Outline of an Atom Feed.
|
574 |
+
|
575 |
+
element explanation
|
576 |
+
feed elements
|
577 |
+
<title> The title of the feed containing a canonicalized query string.
|
578 |
+
<id> A unique id assigned to this query.
|
579 |
+
<updated> The last time search results for this query were updated. Set to midnight of the current day.
|
580 |
+
<link> A url that will retrieve this feed via a GET request.
|
581 |
+
<opensearch:totalResults> The total number of search results for this query.
|
582 |
+
<opensearch:startIndex> The 0-based index of the first returned result in the total results list.
|
583 |
+
<opensearch:itemsPerPage> The number of results returned.
|
584 |
+
entry elements
|
585 |
+
<title> The title of the article.
|
586 |
+
<id> A url http://arxiv.org/abs/id
|
587 |
+
<published> The date that version 1 of the article was submitted.
|
588 |
+
<updated> The date that the retrieved version of the article was submitted. Same as <published> if the retrieved version is version 1.
|
589 |
+
<summary> The article abstract.
|
590 |
+
<author> One for each author. Has child element <name> containing the author name.
|
591 |
+
<link> Can be up to 3 given url's associated with this article.
|
592 |
+
<category> The arXiv or ACM or MSC category for an article if present.
|
593 |
+
<arxiv:primary_category> The primary arXiv category.
|
594 |
+
<arxiv:comment> The authors comment if present.
|
595 |
+
<arxiv:affiliation> The author's affiliation included as a subelement of <author> if present.
|
596 |
+
<arxiv:journal_ref> A journal reference if present.
|
597 |
+
<arxiv:doi> A url for the resolved DOI to an external resource if present.
|
598 |
+
5.3. Subject Classifications
|
599 |
+
For the complete list of arXiv subject classifications, please visit the taxonomy page.
|
benchmarks/bibtex-generation/env/bibtex_generation.py
ADDED
File without changes
|
benchmarks/bibtex-generation/env/claude_example.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import anthropic
|
3 |
+
|
4 |
+
client = anthropic.Client(open("claude_api_key.txt").read().strip())
|
5 |
+
response = client.completion(
|
6 |
+
prompt=f"{anthropic.HUMAN_PROMPT} How many toes do dogs have?{anthropic.AI_PROMPT}",
|
7 |
+
stop_sequences = [anthropic.HUMAN_PROMPT],
|
8 |
+
model="claude-v1",
|
9 |
+
max_tokens_to_sample=100,
|
10 |
+
)
|
11 |
+
print(response)
|