{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "qf-uHjOnuw5g"
},
"source": [
"# Node Classification with Graph Neural Networks\n",
"\n",
"**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>\n",
"**Date created:** 2021/05/30<br>\n",
"**Last modified:** 2021/05/30<br>\n",
"**Description:** Implementing a graph neural network model for predicting the topic of a paper given its citations."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "THU5mq3Buw5i"
},
"source": [
"## Introduction\n",
"\n",
"Many datasets in machine learning (ML) applications have structural relationships\n",
"between their entities, which can be represented as graphs. Such applications include\n",
"social and communication network analysis, traffic prediction, and fraud detection.\n",
"[Graph representation learning](https://www.cs.mcgill.ca/~wlh/grl_book/)\n",
"aims to build and train models for graph datasets to be used for a variety of ML tasks.\n",
"\n",
"This example demonstrates a simple implementation of a [Graph Neural Network](https://arxiv.org/pdf/1901.00596.pdf)\n",
"(GNN) model. The model is used for a node prediction task on the [Cora dataset](https://relational.fit.cvut.cz/dataset/CORA)\n",
"to predict the subject of a paper given its words and citation network.\n",
"\n",
"Note that **we implement a Graph Convolution Layer from scratch** to provide a better\n",
"understanding of how it works. However, there are a number of specialized TensorFlow-based\n",
"libraries that provide rich GNN APIs, such as [Spektral](https://graphneural.network/),\n",
"[StellarGraph](https://stellargraph.readthedocs.io/en/stable/README.html), and\n",
"[GraphNets](https://github.com/deepmind/graph_nets)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "RK6CHiyAuw5j"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "cCWyYWzLuw5j"
},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"import numpy as np\n",
"import networkx as nx\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf\n",
"from tensorflow import keras\n",
"from tensorflow.keras import layers"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cTFeuvsYuw5j"
},
"source": [
"## Prepare the Dataset\n",
"\n",
"The Cora dataset consists of 2,708 scientific papers, each classified into one of seven classes.\n",
"The citation network consists of 5,429 links. Each paper has a binary word vector of size\n",
"1,433, where each entry indicates the presence of a corresponding word.\n",
"\n",
"### Download the dataset\n",
"\n",
"The dataset has two tab-separated files: `cora.cites` and `cora.content`.\n",
"\n",
"1. The `cora.cites` file includes the citation records with two columns:\n",
"`cited_paper_id` (target) and `citing_paper_id` (source).\n",
"2. The `cora.content` file includes the paper content records with 1,435 columns:\n",
"`paper_id`, `subject`, and 1,433 binary features.\n",
"\n",
"Let's download the dataset."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "7OHK8dAguw5k",
"outputId": "1428c9be-d01e-40b5-8091-fc5dc5698c12",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Downloading data from https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz\n",
"172032/168052 [==============================] - 0s 2us/step\n",
"180224/168052 [================================] - 0s 2us/step\n"
]
}
],
"source": [
"zip_file = keras.utils.get_file(\n",
"    fname=\"cora.tgz\",\n",
"    origin=\"https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz\",\n",
"    extract=True,\n",
")\n",
"data_dir = os.path.join(os.path.dirname(zip_file), \"cora\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "n3H5HwUsuw5k"
},
"source": [
"### Process and visualize the dataset\n",
"\n",
"Then we load the citations data into a Pandas DataFrame."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "wloojBfEuw5l",
"outputId": "6260336f-fea2-47dd-e653-a55cdf995218",
"colab": {
"base_uri": "https://localhost:8080/"
}
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Citations shape: (5429, 2)\n"
]
}
],
"source": [
"citations = pd.read_csv(\n",
"    os.path.join(data_dir, \"cora.cites\"),\n",
"    sep=\"\\t\",\n",
"    header=None,\n",
"    names=[\"target\", \"source\"],\n",
")\n",
"print(\"Citations shape:\", citations.shape)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lRztKFR8uw5l"
},
"source": [
"Now we display a sample of the `citations` DataFrame.\n",
"The `target` column includes the paper ids cited by the paper ids in the `source` column."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "6-Kdix_1uw5l",
"outputId": "639e4995-89b1-4d52-cd42-9c9718a0c8d8",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 207
}
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/html": [
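"<table>\n",
"  <thead>\n",
"    <tr><th></th><th>target</th><th>source</th></tr>\n",
"  </thead>\n",
"  <tbody>\n",
"    <tr><th>1393</th><td>6741</td><td>51909</td></tr>\n",
"    <tr><th>5347</th><td>696345</td><td>696342</td></tr>\n",
"    <tr><th>444</th><td>1365</td><td>26850</td></tr>\n",
"    <tr><th>5312</th><td>671269</td><td>1154124</td></tr>\n",
"    <tr><th>4945</th><td>523574</td><td>653441</td></tr>\n",
"  </tbody>\n",
"</table>\n",
"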