teachyourselfcoding's picture
Upload 245 files
fa6856c
raw
history blame
8.58 kB
from typing import Callable, Dict, List, Optional, Tuple
import networkx as nx
import numpy as np
import torch
def generate_rand_int_excluding(rng: np.random.RandomState, max: int, exclude: int) -> int:
    """Random integer generator, excluding a specific number

    Uses rejection sampling: integers are drawn from `rng` until one differs
    from `exclude`, so the number of draws consumed from `rng` varies.

    Args:
        rng: Numpy random number generator
        max: Exclusive upper bound of the range to draw from
        exclude: Number to exclude

    Returns:
        Random integer in [0, max), excluding the `exclude` integer.

    Raises:
        ValueError: If `exclude` is the only value in [0, max) (i.e.
            `max == 1` and `exclude == 0`), in which case no valid integer
            exists and the sampling loop would never terminate.
    """
    # With max == 1 the only possible draw is 0; if that is also the excluded
    # value, rejection sampling below would spin forever.
    if max == 1 and exclude == 0:
        raise ValueError(
            f"No integer in [0, {max}) differs from the excluded value {exclude}"
        )

    while True:
        # Create the random integer
        candidate = rng.randint(max)

        # Return the random integer if it isn't the exclude value, otherwise
        # try again
        if candidate != exclude:
            return candidate
def generate_random_walks(  # noqa: max-complexity
    n_nodes: int = 21,
    max_length: int = 10,
    n_walks: int = 1000,
    p_edge: float = 0.1,
    seed: int = 1002,
    gpt2_tokenizer: bool = False,
) -> Tuple[Callable[[List[str]], Dict[str, List[float]]], List[str], List[str], torch.Tensor]:
    """Generate random walks

    Builds a random directed graph whose node 0 is an absorbing goal node,
    samples random walks on it, and returns a metric function that scores
    walks against the shortest path to the goal.

    Args:
        n_nodes: Number of nodes. This should not be more than 26, as we use
            single letters to represent each node.
        max_length: Maximum number of steps in each random walk
        n_walks: Number of random walks (samples) to create
        p_edge: Probability that any source node connects to any other
            destination node
        seed: Random seed
        gpt2_tokenizer: True if GPT2's tokenizer is being used

    Returns:
        Tuple of:
            - metric function, mapping a batch of walk strings to a dict with
              the keys "lengths" and "optimality" (one value per batch item)
            - evaluation prompts: the sorted unique start characters of the
              sampled walks, each followed by the delimiter
            - the sampled walks, as strings of node letters joined by the
              delimiter
            - logit mask: the adjacency matrix as a torch tensor

    Raises:
        ValueError: If `n_nodes` is less than 2 or `p_edge` is not greater
            than 0. In both cases no graph satisfying the constraints below
            can be generated.
    """
    # Guard against inputs for which the graph generation loop below could
    # never succeed (a single node has no possible outgoing edge, and with
    # p_edge <= 0 no edge is ever created).
    if n_nodes < 2:
        raise ValueError("n_nodes must be at least 2 (the goal node plus one start node)")
    if not p_edge > 0:
        raise ValueError("p_edge must be greater than 0")

    # Initialise a random state with the seed
    rng = np.random.RandomState(seed)

    # Create the adjacency matrix
    # https://en.wikipedia.org/wiki/Adjacency_matrix
    # This is a 2d matrix, where the rows represent the source nodes and the
    # columns represent the destination nodes. If a cell (i,j) is True, then
    # there is a directional edge from the source node (i) to the destination
    # node (j). If it is false there is no connection.
    while True:
        # Create the adjacency matrix, where each node is connected to each
        # other node, with probability p_edge
        adjacency_matrix: np.ndarray = rng.rand(n_nodes, n_nodes) > (1 - p_edge)

        # Nodes can't be connected to themselves, so the diagonal values must
        # all be False
        np.fill_diagonal(adjacency_matrix, 0)

        # Each source node (row) must have at least one outgoing edge, as the
        # walk sampling below picks the next node from the current node's
        # row. This checks there is a True value in every row. If it is not
        # the case, we try to generate a new adjacency matrix again from
        # scratch (in the while loop).
        if np.all(adjacency_matrix.sum(1)):
            break

    # Set the goal node as 0
    goal: int = 0

    # The goal node is the terminal state, so we make sure that it doesn't
    # have a directional edge going to any other nodes (i.e. it can only be
    # connected to from previous nodes). We also set the connection to itself
    # as True.
    adjacency_matrix[goal, :] = 0
    adjacency_matrix[goal, goal] = 1

    # Create dicts for converting nodes into characters and vice versa
    # Nodes are converted into characters as these (when split by the
    # delimiter) are guaranteed to be tokenized as individual tokens.
    char_to_node: Dict[str, int] = {chr(ix + ord("a")): ix for ix in range(n_nodes)}
    node_to_char: Dict[int, str] = {ix: chr(ix + ord("a")) for ix in range(n_nodes)}

    # Initialise a list of sample walks
    sample_walks: List[str] = []

    # String delimiter (to force the tokenizer to keep all nodes as separate
    # tokens)
    delimiter: str = "|" if gpt2_tokenizer else ""

    # Create n_walks samples
    for _ in range(n_walks):
        # Create a random starting node (that isn't already at the goal state)
        node: int = generate_rand_int_excluding(rng, n_nodes, goal)

        # Initialise the list of nodes that we visit
        walk_nodes: List[int] = [node]

        # Do a series of steps, until we hit the maximum number of steps or
        # the goal state (whichever comes first)
        for _step in range(max_length - 1):
            # From the current node, get all the nodes we can move to. Pick
            # one of these at random, and add it to the list of visited nodes
            node = int(rng.choice(np.nonzero(adjacency_matrix[node])[0]))
            walk_nodes.append(node)

            # If we're at the goal state, stop
            if node == goal:
                break

        # Convert the nodes visited to letters (not integers) and concatenate
        # into a journey, with each node letter separated by the delimiter.
        sample_walks.append(delimiter.join(node_to_char[ix] for ix in walk_nodes))

    # Create a directional graph from the adjacency matrix
    directional_graph = nx.from_numpy_array(adjacency_matrix, create_using=nx.DiGraph)

    # Shortest path length (counted in nodes, capped at max_length) from each
    # node to the goal node. Keyed by node so lookups don't depend on any
    # iteration order. The goal node itself is included, so that a sample
    # starting at the goal gets its own entry.
    shortest_lengths: Dict[int, int] = {}
    for start in range(n_nodes):
        try:
            shortest_path = nx.shortest_path(directional_graph, start, goal)
        except nx.NetworkXNoPath:
            # If there is no path, use the maximum length instead
            shortest_lengths[start] = max_length
        else:
            # Cap the path at max_length nodes
            shortest_lengths[start] = len(shortest_path[:max_length])

    def metric_fn(
        samples: List[str],
    ) -> Dict[str, List[float]]:
        """Metric Function

        Args:
            samples: Batch of samples

        Returns:
            Dict of metrics, each with a key of the metric name and value as a
            list of metric values for each batch item.
        """
        # Length to set if the path is invalid
        invalid_path_length: int = 100

        # Initialise batch lengths & reference lengths (the optimal length
        # starting from each batch items specific start node)
        lengths: List[float] = []
        sample_optimal_lengths: List[int] = []

        for sample_str in samples:
            # Remove GPT2 specific tokenizer delimiter
            if gpt2_tokenizer:
                sample_str = sample_str.replace("|", "")

            # Convert the sample into a list of nodes (unknown characters map
            # to n_nodes, which is never a valid node index)
            sample: List[int] = [char_to_node.get(c, n_nodes) for c in sample_str]

            # Initialise the specific sample length
            length: Optional[float] = None

            for position, current in enumerate(sample):
                # If an invalid path is taken (unknown node, or a move along
                # an edge that doesn't exist), set the length to the invalid
                # path score
                if current >= n_nodes or (
                    position > 0 and not adjacency_matrix[sample[position - 1], current]
                ):
                    length = invalid_path_length
                    break

                # Once the goal node is reached, the length is the number of
                # nodes visited so far
                if current == goal:
                    length = position + 1
                    break

            # Catch the cases where the sample is empty or never reaches the
            # goal node
            if length is None:
                length = invalid_path_length

            # Store the batch item length & optimal length starting from the
            # start node
            lengths.append(float(length))

            # An empty sample or an unknown start character has no entry in
            # shortest_lengths. Such samples always have the invalid length,
            # so the optimality numerator below is zero; the fallback of 0
            # just keeps the lookup from failing.
            start_node = sample[0] if sample else n_nodes
            sample_optimal_lengths.append(shortest_lengths.get(start_node, 0))

        # Replace the invalid path marker with max_length, so that invalid
        # paths get the worst score
        lengths_tensor = torch.tensor(lengths, dtype=torch.float)
        bound_lengths: torch.Tensor = torch.where(
            lengths_tensor.eq(invalid_path_length), max_length, lengths_tensor
        ).abs()
        optimal_lengths = torch.as_tensor(sample_optimal_lengths)

        # Optimality scores, as compared to the shortest path (1 for a
        # shortest path, 0 for an invalid one).
        # NOTE: if a start node's optimal length equals max_length the
        # denominator is zero, so the score for that item is inf/nan.
        optimality = (max_length - bound_lengths) / (max_length - optimal_lengths)

        return {
            "lengths": lengths,
            "optimality": optimality.tolist(),
        }

    logit_mask = torch.tensor(adjacency_matrix)

    # Set the evaluation prompts as a list of unique random walk samples,
    # using just the start point (first character) from each sample.
    eval_prompts = sorted({walk[0] for walk in sample_walks})
    eval_prompts = [prompt + delimiter for prompt in eval_prompts]

    return (metric_fn, eval_prompts, sample_walks, logit_mask)