# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch


def flatten_graph(node_embeddings, edge_embeddings, edge_index):
    """
    Flattens a batch of graphs into a single graph of batch size one (with
    one disconnected subgraph per example) so that it is compatible with the
    pytorch-geometric package.

    Entries of ``edge_index`` equal to -1 mark padding; any edge with a -1
    endpoint is dropped, together with its edge embeddings.

    Args:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x nodes x node_embed_dim
            - vector: shape batch_size x nodes x node_embed_dim x 3
        edge_embeddings: edge embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x edges x edge_embed_dim
            - vector: shape batch_size x edges x edge_embed_dim x 3
        edge_index: shape batch_size x 2 (source node and target node) x edges

    Returns:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape total_nodes x node_embed_dim
            - vector: shape total_nodes x node_embed_dim x 3
        edge_embeddings: edge embeddings in tuple form (scalar, vector)
            - scalar: shape total_edges x edge_embed_dim
            - vector: shape total_edges x edge_embed_dim x 3
        edge_index: shape 2 x total_edges
    """
    x_s, x_v = node_embeddings
    e_s, e_v = edge_embeddings
    batch_size, num_nodes = x_s.shape[0], x_s.shape[1]

    # Merge the batch and node/edge dimensions.
    node_embeddings = (torch.flatten(x_s, 0, 1), torch.flatten(x_v, 0, 1))
    e_s = torch.flatten(e_s, 0, 1)
    e_v = torch.flatten(e_v, 0, 1)

    # An edge is valid only if NEITHER endpoint is padding. Keeping an edge
    # with a single -1 endpoint would, after the offset below, make it point
    # at the last node of the previous example (or stay negative for the
    # first example).  Shape: batch_size x edges -> (batch_size * edges,)
    edge_mask = torch.all(edge_index != -1, dim=1).flatten()

    # Re-number the nodes by adding batch_idx * N to each batch so node ids
    # are unique across the flattened graph.
    offsets = torch.arange(batch_size, device=edge_index.device) * num_nodes
    edge_index = edge_index + offsets.view(-1, 1, 1)
    # batch x 2 x E -> 2 x batch x E -> 2 x (batch * E); the edge order
    # matches the flattened edge embeddings and edge_mask.
    edge_index = edge_index.permute(1, 0, 2).flatten(1, 2)

    edge_index = edge_index[:, edge_mask]
    edge_embeddings = (e_s[edge_mask], e_v[edge_mask])
    return node_embeddings, edge_embeddings, edge_index


def unflatten_graph(node_embeddings, batch_size):
    """
    Unflattens node embeddings produced by ``flatten_graph``.

    Args:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape total_nodes x node_embed_dim
            - vector: shape total_nodes x node_embed_dim x 3
        batch_size: int; total_nodes must be divisible by it.

    Returns:
        node_embeddings: node embeddings in tuple form (scalar, vector)
            - scalar: shape batch_size x nodes x node_embed_dim
            - vector: shape batch_size x nodes x node_embed_dim x 3
    """
    x_s, x_v = node_embeddings
    # Split only the leading dimension; all trailing dims are preserved.
    x_s = x_s.reshape(batch_size, -1, *x_s.shape[1:])
    x_v = x_v.reshape(batch_size, -1, *x_v.shape[1:])
    return (x_s, x_v)