The BaselineModel class in the baselines.py file is a full working Graph Neural Network (GNN) example using JAX and the DeepMind JAX Ecosystem of libraries. It allows training of multiple algorithms on a single processor, as described in the paper "A Generalist Neural Algorithmic Learner" (arXiv:2209.11142v2 [cs.LG] 3 Dec 2022). Below is an excerpt from the paper that describes the model:

Each algorithm in the CLRS benchmark [5] is specified by a number of inputs, hints and outputs. In a given sample, the inputs and outputs are fixed, while hints are time-series of intermediate states of the algorithm. Each sample for a particular task has a size, n, corresponding to the number of nodes in the GNN that will execute the algorithm. A sample of every algorithm is represented as a graph, with each input, output and hint located in either the nodes, the edges, or the graph itself, and therefore has shape (excluding batch dimension, and, for hints, time dimension) n × f, n × n × f, or f, respectively, f being the dimensionality of the feature, which depends on its type. The CLRS benchmark defines five types of features: scalar, categorical, mask, mask_one and pointer, with their own encoding and decoding strategies and loss functions; e.g. a scalar type will be encoded and decoded directly by a single linear layer, and optimised using mean squared error.

Base Model

Encoder. We adopt the same encode-process-decode paradigm [33] presented with the CLRS benchmark [5]. At each time step, t, of a particular task τ (e.g. insertion sort), the task-based encoder $f_\tau$, consisting of a linear encoder for each input and hint, embeds inputs and the current hints as high-dimensional vectors. These embeddings of inputs and hints located in the nodes all have the same dimension and are added together; the same happens with hints and inputs located in edges, and in the graph.
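As a rough illustration of the encoder described above, the following JAX sketch embeds each node-level feature with its own linear layer and adds the results. The helper names (`init_encoders`, `encode_nodes`), the feature names and the initialisation scale are assumptions made for this example; none of them are taken from baselines.py.

```python
import jax
import jax.numpy as jnp

H = 128  # embedding size h quoted in the paper excerpt


def init_encoders(key, feature_dims, h=H):
    """One linear encoder (W, b) per named feature; the names are hypothetical."""
    keys = jax.random.split(key, len(feature_dims))
    return {
        name: (jax.random.normal(k, (f, h)) / jnp.sqrt(f), jnp.zeros((h,)))
        for k, (name, f) in zip(keys, feature_dims.items())
    }


def encode_nodes(params, node_features):
    """Embed every node-level input/hint to shape (n, h) and add the embeddings."""
    embeddings = [node_features[name] @ w + b for name, (w, b) in params.items()]
    return jnp.sum(jnp.stack(embeddings), axis=0)


# Two hypothetical node features for n = 5 nodes: a scalar (f = 1) and a
# categorical feature with 3 classes (f = 3).
dims = {"scalar_input": 1, "categorical_hint": 3}
params = init_encoders(jax.random.PRNGKey(0), dims)
features = {
    "scalar_input": jnp.ones((5, 1)),
    "categorical_hint": jnp.ones((5, 3)),
}
x = encode_nodes(params, features)  # shape (5, 128)
```

Because the per-feature embeddings are summed, the output shape depends only on n and h, not on how many inputs or hints a given algorithm defines.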
In our experiments we use the same dimension, h = 128, for node, edge and graph embeddings. Thus, at the end of the encoding step for a time-step t of the algorithm, we have a single set of embeddings $\{x_i^{(t)}, e_{ij}^{(t)}, g^{(t)}\}$, shapes n × h, n × n × h, and h, in the nodes, edges and graph, respectively. Note that this is independent of the number and type of the inputs and hints of the particular algorithm, allowing us to share this latent space across all thirty algorithms in CLRS. Further, note that at each step, the input encoding is fed directly to these embeddings; this recall mechanism significantly improves the model’s robustness over long trajectories [34].

Processor. The embeddings are fed into a processor P, a GNN that performs one step of computation. The processor transforms the input node, edge and graph embeddings into processed node embeddings, $h_i^{(t)}$. Additionally, the processor uses the processed node embeddings from the previous step, $h_i^{(t-1)}$, as inputs. Importantly, the same processor model can operate on graphs of any size. We leverage the message-passing neural network [35, MPNN], using the max aggregation and passing messages over a fully-connected graph, as our base model. The MPNN computes processed embeddings as follows:

$$ z^{(t)} = x_i^{(t)} \,\|\, h_i^{(t-1)}, \qquad m_i^{(t)} = \max_{1 \le j \le n} f_m\big(z_i^{(t)}, z_j^{(t)}, e_{ij}^{(t)}, g^{(t)}\big), \qquad h_i^{(t)} = f_r\big(z_i^{(t)}, m_i^{(t)}\big) \tag{1} $$

starting from $h^{(0)} = 0$. Here $\|$ denotes concatenation, $f_m : \mathbb{R}^{2h} \times \mathbb{R}^{2h} \times \mathbb{R}^{h} \times \mathbb{R}^{h} \to \mathbb{R}^{h}$ is the message function (for which we use a three-layer MLP with ReLU activations), and $f_r : \mathbb{R}^{2h} \times \mathbb{R}^{h} \to \mathbb{R}^{h}$ is the readout function (for which we use a linear layer with ReLU activation). The use of the max aggregator is well-motivated by prior work [5, 9], and we use the fully connected graph, letting the neighbours j range over all nodes (1 ≤ j ≤ n), in order to allow the model to overcome situations where the input graph structure may be suboptimal.
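The update in Equation 1 can be sketched in JAX as follows. This is a minimal illustration under stated assumptions, not the code in baselines.py: the parameter layout, the helper names `init_linear`, `init_mpnn` and `mpnn_step`, the initialisation, and the choice to leave the last layer of f_m without a ReLU are all hypothetical, and layer normalisation is omitted.

```python
import jax
import jax.numpy as jnp


def init_linear(key, in_dim, out_dim):
    """Hypothetical helper: a linear layer stored as a (W, b) pair."""
    w = jax.random.normal(key, (in_dim, out_dim)) / jnp.sqrt(in_dim)
    return w, jnp.zeros((out_dim,))


def init_mpnn(key, h):
    """Parameters for f_m (three-layer MLP) and f_r (one linear layer)."""
    k1, k2, k3, k4 = jax.random.split(key, 4)
    return {
        # f_m sees [z_i, z_j, e_ij, g]: 2h + 2h + h + h = 6h features.
        "f_m": [init_linear(k1, 6 * h, h), init_linear(k2, h, h), init_linear(k3, h, h)],
        # f_r sees [z_i, m_i]: 2h + h = 3h features.
        "f_r": init_linear(k4, 3 * h, h),
    }


def mpnn_step(params, x, e, g, h_prev):
    """One step of Equation 1 for a single graph (no batch dimension).

    x: (n, h) node, e: (n, n, h) edge, g: (h,) graph embeddings;
    h_prev: (n, h) processed node embeddings from the previous step.
    """
    n = x.shape[0]
    z = jnp.concatenate([x, h_prev], axis=-1)                    # (n, 2h)
    z_i = jnp.broadcast_to(z[:, None, :], (n, n, z.shape[-1]))   # receiver i
    z_j = jnp.broadcast_to(z[None, :, :], (n, n, z.shape[-1]))   # sender j
    g_ij = jnp.broadcast_to(g, (n, n, g.shape[-1]))
    msg = jnp.concatenate([z_i, z_j, e, g_ij], axis=-1)          # (n, n, 6h)
    (w1, b1), (w2, b2), (w3, b3) = params["f_m"]
    msg = jax.nn.relu(msg @ w1 + b1)
    msg = jax.nn.relu(msg @ w2 + b2)
    msg = msg @ w3 + b3                    # assumption: no ReLU on the last layer
    m = jnp.max(msg, axis=1)               # max over all senders j (fully connected)
    w_r, b_r = params["f_r"]
    return jax.nn.relu(jnp.concatenate([z, m], axis=-1) @ w_r + b_r)


# Usage: the same parameters apply to graphs of any size n.
params = init_mpnn(jax.random.PRNGKey(0), 8)
h1 = mpnn_step(params, jnp.ones((4, 8)), jnp.ones((4, 4, 8)), jnp.ones((8,)), jnp.zeros((4, 8)))
```

Since the parameters depend only on h, the same `params` can be reused for any n, which mirrors the remark that the processor operates on graphs of any size.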
Layer normalisation [36] is applied to $h_i^{(t)}$ before using them further. Further details on the MPNN processor may be found in Veličković et al. [5].

Decoder. The processed embeddings are finally decoded with a task-based decoder $g_\tau$, to predict the hints for the next step, and the outputs at the final step. Akin to the encoder, the task-based decoder relies mainly on a linear decoder for each hint and output, along with a mechanism to compute pairwise node similarities when appropriate. Specifically, the pointer type decoder computes a score, $s_{ij}$, for each pair of nodes, and then chooses the pointer of node i by taking either the $\mathrm{argmax}_j\, s_{ij}$ or $\mathrm{softmax}_j\, s_{ij}$ (depending on whether a hard or soft prediction is used).

Loss. The decoded hints and outputs are used to compute the loss during training, according to their type [5]. For each sample in a batch, the hint prediction losses are averaged across hints and time, and the output loss is averaged across outputs (most algorithms have a single output, though some have two outputs). The hint loss and output loss are added together. Besides, the hint predictions at each time step are fed back as inputs for the next step, except possibly at train time if teacher forcing is used (see Section 3.2.1).

We train the model on samples with sizes n ≤ 16, and periodically evaluate it on in-distribution samples of size n = 16. Also, periodically, we evaluate the model with the best in-distribution evaluation score so far on OOD samples of size n = 64. In what follows, we will be reporting only these OOD evaluation scores. Full details of the model, training and evaluation hyperparameters can be found in Appendix A.

3.2 Model improvements

As previously discussed, single-task improvements, especially in terms of learning stability, will empirically transfer well to multi-task algorithmic learning.
We now describe, in a gradual manner, all the changes made to the model, which have led to an absolute improvement of over 20% on average across all 30 tasks in CLRS.

3.2.1 Dataset and training

Removing teacher forcing. At evaluation time, the model has no access to the step-by-step hints in the dataset, and has to rely on its own hint predictions. However, during training, it is sometimes advisable to stabilise the trajectories with teacher forcing [37], that is, providing the ground-truth hint values instead of the network’s own predictions. In the prior model [5], ground-truth hints were provided during training with probability 0.5, as, without teacher forcing, losses tended to grow unbounded along a trajectory when scalar hints were present, destabilising the training. In this work we incorporate several significant stabilising changes (described in future paragraphs), which allow us to remove teacher forcing altogether, aligning training with evaluation, and avoiding the network becoming overconfident in always expecting correct hint predictions. With teacher forcing, performance deteriorates significantly in sorting algorithms and Kruskal’s algorithm. Naïve String Matcher, on the other hand, improves with teacher forcing (see Appendix A, Figs. 7-9).

Augmenting the training data. To prevent our model from over-fitting to the statistics of the fixed CLRS training dataset [5], we augmented the training data in three key ways, without breaking the intended size distribution shift. Firstly, we used the on-line samplers in CLRS to generate new training examples on the fly, rather than using a fixed dataset which is easier to overfit to. Secondly, we trained on examples of mixed sizes, n ≤ 16, rather than only 16, which helps the model anticipate a diverse range of sizes, rather than overfitting to the specifics of size n = 16.
Lastly, for graph algorithms, we varied the connectivity probability p of the input graphs (generated by the Erdős-Rényi model [38]); and for string matching algorithms, we varied the length of the pattern to be matched. These both serve to expose the model to different trajectory lengths; for example, in many graph algorithms, the number of steps the algorithm should run for is related to the graph’s diameter, and varying the connection probability in the graph generation allows for varying the expected diameter. These changes considerably increase training data variability, compared to the original dataset in Veličković et al. [5]. We provide a more detailed step-by-step overview of the data generation process in Appendix A.

Soft hint propagation. When predicted hints are fed back as inputs during training, gradients may or may not be allowed to flow through them. In previous work, only hints of the scalar type allowed gradients through, as all categoricals were post-processed from logits into the ground-truth format via argmax or thresholding before being fed back. Instead, in this work we use softmax for categorical, mask_one and pointer types, and the logistic sigmoid for mask types. Without these soft hints, performance in sorting algorithms degrades (similarly to the case of teacher forcing), as well as in Naïve String Matcher (Appendix A, Figs. 7-9).

Static hint elimination. Eleven algorithms in CLRS (listed in footnote 3) specify a fixed ordering of the nodes, common to every sample, via a node pointer hint that does not ever change along the trajectories. Prediction of this hint is trivial (identity function), but poses a potential problem for OOD generalisation, since the model can overfit to the fixed training values. We therefore turned this fixed hint into an input for these 11 algorithms, eliminating the need for explicitly predicting it.

Improving training stability with encoder initialisation and gradient clipping.
The scalar hints have unbounded values, in principle, and are optimised using mean-squared error, hence their gradients can quickly grow with increasing prediction error. Further, the predicted scalar hints then get re-encoded at every step, which can rapidly amplify errors throughout the trajectory, leading to exploding signals (and consequently gradients), even before any training takes place. To rectify this issue, we use the Xavier initialisation [45], effectively reducing the initial weights for scalar hints whose input dimensionality is just 1. However, we reverted to using the default LeCun initialisation [46] elsewhere. This combination of initialisations proved important for the initial learning stability of our model over long trajectories. Relatedly, in preliminary experiments, we saw drastic improvements in learning stability, as well as significant increases in validation performance, with gradient clipping [47], which we subsequently employed in all experiments.

Footnote 3: Binary Search, Minimum, Max Subarray [39], Matrix Chain Order, LCS Length, Optimal BST [40], Activity Selector [41], Task Scheduling [42], Naïve String Matcher, Knuth-Morris-Pratt [43] and Jarvis’ March [44].

3.2.2 Encoders and decoders

Randomised position scalar. Across all algorithms in the dataset, there exists a position scalar input which uniquely indexes the nodes, with values linearly spaced between 0 and 1 along the node index. To avoid overfitting to these linearly spaced values during training, we replaced them with random values, uniformly sampled in [0, 1], sorted to match the initial order implied by the linearly spaced values. The benefit of this change is notable in algorithms where it would be easy to overfit to these positions, such as string matching.
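A minimal JAX sketch of the randomised position scalar described above. The function name `random_positions` is hypothetical and the snippet is not taken from baselines.py; it only illustrates the idea of sorted uniform samples.

```python
import jax
import jax.numpy as jnp


def random_positions(key, n):
    """Sample n values uniformly from the unit interval and sort them.

    Sorting keeps the node order implied by the original linearly spaced
    positions while removing the fixed spacing the model could overfit to.
    """
    return jnp.sort(jax.random.uniform(key, (n,)))


# Usage: draw fresh positions for each training sample of size n = 16.
pos = random_positions(jax.random.PRNGKey(0), 16)
```

A new key would be used for every sample, so the model never sees the same spacing twice during training.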
Namely, the model could learn to base all of its computations on the assumption that it will always be finding an m-character pattern inside an n-character string, even though at test time, m and n will increase fourfold.

Permutation decoders and the Sinkhorn operator. Sorting algorithms (Insertion Sort, Bubble Sort, Heapsort [48] and Quicksort [49]) always output a permutation of the input nodes. In the CLRS benchmark, this permutation is encoded as a pointer where each node points to its predecessor in the sorted order (the first node points to itself); this is represented as an n × n matrix P where each row is a one-hot vector, such that element (i, j) is 1 if node i points to node j. As with all types of pointers, such permutation pointers can be predicted using a row-wise softmax on unconstrained decoder outputs (logits), trained with cross entropy (as in Veličković et al. [5]). However, this does not explicitly take advantage of the fact that the pointers encode a permutation, which the model has to learn instead. Our early experiments showed that the model was often failing to predict valid permutations OOD. Accordingly, we enforce a permutation inductive bias in the output decoder of sorting algorithms, as follows.

First, we modify the output representation by rewiring the first node to point to the last one, turning P into a permutation matrix, i.e., a matrix whose rows and columns are one-hot vectors. We also augment the representation with a one-hot vector of size n that specifies the first node, so we do not lose this information; this vector is treated like a regular mask_one feature.

Second, we predict the permutation matrix P from unconstrained decoder outputs Y by replacing the usual row-wise softmax with the Sinkhorn operator S [32, 50–53]. S projects an arbitrary square matrix Y into a doubly stochastic matrix S(Y) (a non-negative matrix whose rows and columns sum to 1), by exponentiating and repeatedly normalizing rows and columns so they sum to 1.
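The exponentiate-then-normalise procedure can be sketched in JAX as below. This is an illustrative approximation under stated assumptions, not the implementation in baselines.py: the function name `sinkhorn`, its arguments, the default iteration count and the log-space formulation (chosen here for numerical stability) are all choices made for this example. Dividing the logits by a `temperature` is included as an optional scaling; with the default of 1.0 it has no effect.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sinkhorn(logits, num_iters=10, temperature=1.0):
    """Approximate the Sinkhorn operator with a fixed number of iterations.

    Works in log space: subtracting the log-sum-exp of a row (column) is the
    same as dividing the exponentiated row (column) by its sum.
    """
    log_s = logits / temperature
    for _ in range(num_iters):
        log_s = log_s - logsumexp(log_s, axis=1, keepdims=True)  # rows sum to 1
        log_s = log_s - logsumexp(log_s, axis=0, keepdims=True)  # columns sum to 1
    return jnp.exp(log_s)


# Usage on hypothetical 5 x 5 decoder outputs.
y = jax.random.normal(jax.random.PRNGKey(0), (5, 5))
s = sinkhorn(y, num_iters=50)
```

After the last iteration the columns sum to 1 exactly (up to floating-point error) and the rows approximately, with the row error shrinking as `num_iters` grows.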
Specifically, S is defined by:

$$ S^{0}(Y) = \exp(Y), \qquad S^{l}(Y) = T_c\big(T_r(S^{l-1}(Y))\big), \qquad S(Y) = \lim_{l \to \infty} S^{l}(Y), \tag{2} $$

where exp acts element-wise, and $T_r$ and $T_c$ denote row and column normalisation respectively. Although the Sinkhorn operator produces a doubly stochastic matrix rather than a permutation matrix, we can obtain a permutation matrix by introducing a temperature parameter, τ > 0, and taking $P = \lim_{\tau \to 0^{+}} S(Y/\tau)$; as long as there are no ties in the elements of Y, P is guaranteed to be a permutation matrix [52, Theorem 1].

In practice, we compute the Sinkhorn operator using a fixed number of iterations $l_{\max}$. We use a smaller number of iterations $l_{\max} = 10$ for training, to limit vanishing and exploding gradients, and $l_{\max} = 60$ for evaluation. A fixed temperature τ = 0.1 was experimentally found to give a good balance between speed of convergence and tie-breaking. We also encode the fact that no node points to itself, that is, that all diagonal elements of P should be 0, by setting the diagonal elements of Y to −∞. To avoid ties, we follow Mena et al. [53], injecting Gumbel noise to the elements of Y prior to applying the Sinkhorn operator, during training only. Finally, we transform the predicted matrix P, and mask_one pointing to the first element, into the original pointer representation used by CLRS.

3.2.3 Processor networks

Gating mechanisms. Many algorithms only require updating a few nodes at each time step, keeping the rest unchanged. However, the MPNN we use (Equation 1) is biased towards the opposite: it updates all hidden states in each step. Although it is theoretically possible for the network to keep the states unchanged, learning to do so is not easy. With this in mind, and motivated by its effectiveness in NDRs [54], we augment the network with an update gate, biased to be closed by default. We found that the gate stabilizes learning on many of the tasks, and increases the mean performance over all tasks on single-task training significantly.
Surprisingly, however, we did not find gating to be advantageous in the multi-task case. To add gating to the MPNN model we produce a per-node gating vector from the same inputs that process the embeddings in Equation 1:

$$ g_i^{(t)} = f_g\big(z_i^{(t)}, m_i^{(t)}\big) \tag{3} $$

where $f_g : \mathbb{R}^{2h} \times \mathbb{R}^{h} \to \mathbb{R}^{h}$ is the gating function, for which we use a two-layer MLP, with ReLU activation for the hidden layer and logistic sigmoid activation for the output. Importantly, the final layer bias of $f_g$ is initialized to a value of −3, which biases the network for not updating its representations, unless necessary. The processed gated embeddings, $\hat{h}_i^{(t)}$, are computed as follows:

$$ \hat{h}_i^{(t)} = g_i^{(t)} \odot h_i^{(t)} + \big(1 - g_i^{(t)}\big) \odot \hat{h}_i^{(t-1)} \tag{4} $$

and are used instead of $h_i^{(t)}$ in the subsequent steps, replacing $z^{(t)}$ in Eq. 1 by $z^{(t)} = x_i^{(t)} \,\|\, \hat{h}_i^{(t-1)}$.

[Figure 2: bar chart of the average score [%] per algorithm, comparing our model with the previous SOTA [5].]
Figure 2: The OOD performance in single-task experiments before and after the improvements presented in this paper, sorted in descending order of current performance. Error bars represent standard error of the mean across seeds (3 seeds for previous SOTA experiments, 10 seeds for current). The previous SOTA values are the best of MPNN, PGN and Memnet models (see Table 2).

Triplet reasoning. Several algorithms within CLRS-30 explicitly require edge-based reasoning, where edges store values, and update them based on other edges’ values. An example of this is the Floyd-Warshall algorithm [55], which computes all-pairs shortest paths in a weighted graph.
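Returning briefly to the gating mechanism, Equations 3 and 4 can be sketched in JAX as follows. The names `init_gate` and `gated_update`, the parameter layout and the weight initialisation are assumptions made for this example; only the two-layer structure, the activations and the −3 output bias come from the text.

```python
import jax
import jax.numpy as jnp


def init_gate(key, h):
    """Two-layer MLP f_g; the output bias starts at -3 so the gate is mostly closed."""
    k1, k2 = jax.random.split(key)
    w1 = jax.random.normal(k1, (3 * h, h)) / jnp.sqrt(3 * h)  # input is [z_i, m_i]
    w2 = jax.random.normal(k2, (h, h)) / jnp.sqrt(h)
    return (w1, jnp.zeros((h,))), (w2, jnp.full((h,), -3.0))


def gated_update(gate_params, z, m, h_new, h_hat_prev):
    """Blend new and previous node states with a per-node, per-feature gate.

    z: (n, 2h), m: (n, h), h_new: (n, h), h_hat_prev: (n, h).
    """
    (w1, b1), (w2, b2) = gate_params
    hidden = jax.nn.relu(jnp.concatenate([z, m], axis=-1) @ w1 + b1)
    gate = jax.nn.sigmoid(hidden @ w2 + b2)            # Equation 3, values in (0, 1)
    return gate * h_new + (1.0 - gate) * h_hat_prev    # Equation 4


# Usage with hypothetical sizes n = 3, h = 4.
gate_params = init_gate(jax.random.PRNGKey(0), 4)
h_hat = gated_update(gate_params, jnp.ones((3, 8)), jnp.ones((3, 4)),
                     jnp.ones((3, 4)), jnp.zeros((3, 4)))
```

With a pre-activation of exactly −3 the gate equals sigmoid(−3), roughly 0.05, so at initialisation each node mostly keeps its previous state.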
The update rule for $d_{ij}$, its estimate for the best distance from node i to j, is $d_{ij} = \min_k (d_{ik} + d_{kj})$, which roughly says “the best way to get from i to j is to find the optimal mid-point k, travel from i to k, then from k to j”. Similar rules are pervasive across many CLRS-30 algorithms, especially in dynamic programming. Even though there are no node representations in the above update, all our processors are centered on passing messages between node representations $h_i$.

To rectify this situation, we augment our processor to perform message passing towards edges. Referring again to the update for $d_{ij}$, we note that the edge representations are updated by choosing an intermediate node, then aggregating over all possible choices. Accordingly, and as previously observed by Dudzik and Veličković [31], we introduce triplet reasoning: first, computing representations over triplets of nodes, then reducing over one node to obtain edge latents:

$$ t_{ijk} = \psi_t(h_i, h_j, h_k, e_{ij}, e_{ik}, e_{kj}, g), \qquad h_{ij} = \phi_t\big(\max_k t_{ijk}\big) \tag{5} $$

Here, $\psi_t$ is a triplet message function, mapping all relevant representations to a single vector for each triplet of nodes, and $\phi_t$ is an edge readout function, which transforms the aggregated triplets for each edge for later use. According to prior findings on the CLRS benchmark [5], we use the max aggregation to obtain edge representations. The computed $h_{ij}$ vectors can then be used in any edge-based reasoning task, and empirically they are indeed significantly beneficial, even in tasks where we did not initially anticipate such benefits. One example is Kruskal’s minimum spanning tree algorithm [56], where we presume that access to triplet reasoning allowed the model to more easily sort the edges by weight, as it selects how to augment the spanning forest at each step.

In order to keep the footprint of triplet embeddings as lightweight as possible, we compute only 8-dimensional features in $\psi_t$.
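To make the analogy concrete, the sketch below first writes the Floyd-Warshall style relaxation as a reduction over the intermediate node k, then mirrors the same pattern for Equation 5. It is a hedged illustration, not the code in baselines.py: the function names, the parameter layout, and the choice of single linear layers for ψ_t and φ_t (with a ReLU on the readout) are assumptions.

```python
import jax
import jax.numpy as jnp


def floyd_warshall_relax(d):
    """d_ij <- min_k (d_ik + d_kj): build an (i, j, k) tensor, then reduce over k."""
    return jnp.min(d[:, None, :] + d.T[None, :, :], axis=2)


def init_triplet(key, h, t_dim=8):
    """psi_t maps 7h concatenated features to t_dim; phi_t maps t_dim back to h."""
    k1, k2 = jax.random.split(key)
    psi = (jax.random.normal(k1, (7 * h, t_dim)) / jnp.sqrt(7 * h), jnp.zeros((t_dim,)))
    phi = (jax.random.normal(k2, (t_dim, h)) / jnp.sqrt(t_dim), jnp.zeros((h,)))
    return {"psi": psi, "phi": phi}


def triplet_edge_latents(params, h_nodes, e, g):
    """Equation 5 for one graph. h_nodes: (n, h), e: (n, n, h), g: (h,)."""
    n, h = h_nodes.shape
    shape = (n, n, n, h)  # axes are (i, j, k, feature)
    feats = jnp.concatenate([
        jnp.broadcast_to(h_nodes[:, None, None, :], shape),               # h_i
        jnp.broadcast_to(h_nodes[None, :, None, :], shape),               # h_j
        jnp.broadcast_to(h_nodes[None, None, :, :], shape),               # h_k
        jnp.broadcast_to(e[:, :, None, :], shape),                        # e_ij
        jnp.broadcast_to(e[:, None, :, :], shape),                        # e_ik
        jnp.broadcast_to(jnp.swapaxes(e, 0, 1)[None, :, :, :], shape),    # e_kj
        jnp.broadcast_to(g, shape),                                       # g
    ], axis=-1)                                                           # (n, n, n, 7h)
    w_psi, b_psi = params["psi"]
    t = feats @ w_psi + b_psi              # triplet features, (n, n, n, t_dim)
    agg = jnp.max(t, axis=2)               # max over the intermediate node k
    w_phi, b_phi = params["phi"]
    return jax.nn.relu(agg @ w_phi + b_phi)  # edge latents h_ij, (n, n, h)


# Usage: a 3-node path graph 0 -> 1 -> 2 with unit weights.
inf = jnp.inf
d = jnp.array([[0.0, 1.0, inf], [inf, 0.0, 1.0], [inf, inf, 0.0]])
d_relaxed = floyd_warshall_relax(d)  # entry (0, 2) becomes 1 + 1 = 2
```

The intermediate tensor has n³ entries, which is one reason a small triplet feature size keeps the memory footprint manageable.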
$\phi_t$ then upscales the aggregated edge features back to 128 dimensions, to make them compatible with the rest of the architecture. Our initial experimentation demonstrated that the output dimensionality of $\psi_t$ did not significantly affect downstream performance. Note that computing triplet representations has been a useful approach in general GNN design [57]; however, it has predominantly been studied in the context of GNNs over constant input features. Our study is among the first to verify their utility over reasoning tasks with well-specified initial features.

3.3 Results

Table 1: Single-task OOD micro-F1 score of previous SOTA Memnet, MPNN and PGN [5] and our best model Triplet-GMPNN with all our improvements, after 10,000 training steps.

Alg. Type    | Memnet [5]     | MPNN [5]       | PGN [5]        | Triplet-GMPNN (ours)
-------------+----------------+----------------+----------------+---------------------
Div. & C.    | 13.05% ± 0.14  | 20.30% ± 0.85  | 65.23% ± 4.44  | 76.36% ± 1.34
DP           | 67.94% ± 8.20  | 65.10% ± 6.44  | 70.58% ± 6.48  | 81.99% ± 4.98
Geometry     | 45.14% ± 11.95 | 73.11% ± 17.19 | 61.19% ± 7.01  | 94.09% ± 2.30
Graphs       | 24.12% ± 5.30  | 62.79% ± 8.75  | 60.25% ± 8.42  | 81.41% ± 6.21
Greedy       | 53.42% ± 20.82 | 82.39% ± 3.01  | 75.84% ± 6.59  | 91.21% ± 2.95
Search       | 34.35% ± 21.67 | 41.20% ± 19.87 | 56.11% ± 21.56 | 58.61% ± 24.34
Sorting      | 71.53% ± 1.41  | 11.83% ± 2.78  | 15.45% ± 8.46  | 60.37% ± 12.16
Strings      | 1.51% ± 0.46   | 3.21% ± 0.94   | 2.04% ± 0.20   | 49.09% ± 23.49
-------------+----------------+----------------+----------------+---------------------
Overall avg. | 38.88%         | 44.99%         | 50.84%         | 74.14%
> 90%        | 0/30           | 6/30           | 3/30           | 11/30
> 80%        | 3/30           | 9/30           | 7/30           | 17/30
> 60%        | 10/30          | 14/30          | 15/30          | 24/30

By incorporating the changes described in the previous sections we arrived at a single model type, with a single set of hyper-parameters, that was trained to reach new state-of-the-art performance on CLRS-30 [5]. Tables 1 and 2 show the micro-F1 scores of our model, which we refer to as Triplet-GMPNN (an MPNN with gating and triplet edge processing), over the original CLRS-30 test set (computed identically to Veličković et al. [5], but with 10 repetitions instead of 3).
Our baselines include the Memnet [58], MPNN [35] and PGN [59] models, taken directly from Veličković et al. [5]. Figure 2 displays the comparison between the improved model and the best model from Veličković et al. [5]. Our improvements lead to an overall average performance that is more than 20% higher (in absolute terms) compared to the next best model (see Table 1), and to a significant performance improvement in all but one algorithm family, compared to every other model. Further, our stabilising changes (such as gradient clipping) have empirically reduced the scale of our model’s gradient updates across the 30 tasks, preparing us better for the numerical issues of the multi-task regime. We finally also note that, though we do not show it in Tables 1 and 2, applying the same improvements to the PGN processor leads to an increase in overall performance from 50.84% (Table 1) to 69.31%.

There are two notable examples of algorithm families with significant OOD performance improvement. The first is the geometric algorithms (Segments Intersect, Graham Scan [60] and Jarvis’ March), now solved at approximately 94% OOD, compared to the previous best of about 73%; the second is the string algorithms (Knuth-Morris-Pratt and Naïve String Matcher), for which our model now exceeds 49% compared to the previous best of approximately 3%. The significant overall performance boost is reflected in the increased number of algorithms we can now solve at over 60%, 80% and 90% OOD performance, compared to previous SOTA [5]. Specifically, we now exceed 60% accuracy in 24 algorithms (15 algorithms previously), 80% for 17 algorithms (9 previously) and 90% for 11 algorithms (6 previously).