Graph Neural Network

This notebook illustrates how to perform node classification in a graph using a graph neural network.

[1]:
from IPython.display import SVG
[2]:
import numpy as np
from scipy import sparse
[3]:
from sknetwork.data import art_philo_science
from sknetwork.classification import get_accuracy_score
from sknetwork.gnn import GNNClassifier
from sknetwork.visualization import visualize_graph

Graph with features

Let’s load the art_philo_science toy dataset. It consists of a selection of 30 Wikipedia articles with links between them. Each article is described by the words of its summary that belong to a vocabulary of 11 words. Each article belongs to one of the following 3 categories: arts, philosophy or science.

The goal is to predict the category of some articles (the test set) from the known categories of the other articles (the train set).

[4]:
# Load the dataset with its metadata (names, labels, positions).
graph = art_philo_science(metadata=True)

# Graph structure (article links) and node features (article words)
adjacency, features = graph.adjacency, graph.biadjacency

# Names of the nodes, of the feature columns and of the labels
names, names_features, names_labels = graph.names, graph.names_col, graph.names_labels

# Ground-truth labels and node positions used for visualization
labels_true, position = graph.labels, graph.position
[5]:
# Sparse adjacency matrix of the links between articles (one row/column per article)
adjacency
[5]:
<30x30 sparse matrix of type '<class 'numpy.bool_'>'
        with 240 stored elements in Compressed Sparse Row format>
[6]:
# Titles of the articles, one per node
print(names)
['Isaac Newton' 'Albert Einstein' 'Carl Linnaeus' 'Charles Darwin'
 'Ptolemy' 'Gottfried Wilhelm Leibniz' 'Carl Friedrich Gauss'
 'Galileo Galilei' 'Leonhard Euler' 'John von Neumann' 'Leonardo da Vinci'
 'Richard Wagner' 'Ludwig van Beethoven' 'Bob Dylan' 'Igor Stravinsky'
 'The Beatles' 'Wolfgang Amadeus Mozart' 'Richard Strauss' 'Raphael'
 'Pablo Picasso' 'Aristotle' 'Plato' 'Augustine of Hippo' 'Thomas Aquinas'
 'Immanuel Kant' 'Bertrand Russell' 'David Hume' 'René Descartes'
 'John Stuart Mill' 'Socrates']
[7]:
# Number of articles (nodes)
print(len(names))
30
[8]:
# Sparse feature matrix: one row per article, one column per word of the vocabulary
features
[8]:
<30x11 sparse matrix of type '<class 'numpy.int64'>'
        with 101 stored elements in Compressed Sparse Row format>
[9]:
# Vocabulary: the word associated with each feature column
print(names_features)
['contribution' 'theory' 'invention' 'time' 'modern' 'century' 'study'
 'logic' 'school' 'author' 'compose']
[10]:
# Number of features (words in the vocabulary)
len(names_features)
[10]:
11
[11]:
# Names of the categories (labels)
print(names_labels)
['science' 'arts' 'philosophy']
[12]:
# Number of distinct labels (categories) in the ground truth
n_labels = np.unique(labels_true).size

GCN

The default GNN is a spatial graph convolutional network (GCN). Here we use a single hidden layer. Additional hidden layers can be specified through the parameter dims, the list of output dimensions of the successive layers (the last one being the number of labels).

[13]:
hidden_dim = 5  # size of the single hidden layer

# GCN classifier: one hidden convolution layer, then an output layer with one unit per label
gnn = GNNClassifier(
    dims=[hidden_dim, n_labels],
    layer_types='Conv',
    activations='ReLu',
    verbose=True,
)
[14]:
# Layer-by-layer description of the classifier
print(gnn)
GNNClassifier(
    Convolution(layer_type: conv, out_channels: 5, activation: ReLu, use_bias: True, normalization: both, self_embeddings: True)
    Convolution(layer_type: conv, out_channels: 3, activation: Cross entropy, use_bias: True, normalization: both, self_embeddings: True)
)
[15]:
# Random train/test split: each node is put in the test set with probability 0.5.
# Test nodes get label -1 so that their true label is hidden from the classifier;
# the remaining nodes (train_mask) keep their label and are used for training.
labels = labels_true.copy()
np.random.seed(42)
test_mask = np.random.random(size=len(labels)) < 0.5
train_mask = ~test_mask
labels[test_mask] = -1
[16]:
# Training on the nodes with known labels; a label is predicted for every node
labels_pred = gnn.fit_predict(
    adjacency,
    features,
    labels,
    n_epochs=200,
    random_state=42,
)
In epoch   0, loss: 1.053, train accuracy: 0.462
In epoch  20, loss: 0.834, train accuracy: 0.692
In epoch  40, loss: 0.819, train accuracy: 0.692
In epoch  60, loss: 0.831, train accuracy: 0.692

In epoch  80, loss: 0.839, train accuracy: 0.692
In epoch 100, loss: 0.839, train accuracy: 0.692
In epoch 120, loss: 0.825, train accuracy: 0.692
In epoch 140, loss: 0.771, train accuracy: 0.769
In epoch 160, loss: 0.557, train accuracy: 1.000
In epoch 180, loss: 0.552, train accuracy: 1.000
[17]:
# History for each training epoch: the metrics recorded during fit_predict
gnn.history_.keys()
[17]:
dict_keys(['loss', 'train_accuracy'])
[18]:
# Accuracy on test set, i.e. on the nodes whose label was hidden (-1) during training.
# Deriving the mask from `labels` guarantees we never evaluate on training nodes.
test_mask = labels == -1
get_accuracy_score(labels_true[test_mask], labels_pred[test_mask])
[18]:
1.0
[19]:
# Visualization: nodes colored by predicted label
image = visualize_graph(
    adjacency,
    position=position,
    names=names,
    labels=labels_pred,
)
SVG(image)
[19]:
../../_images/tutorials_gnn_gnn_classifier_24_0.svg
[20]:
# Probability distribution over labels for each node (one column per label)
probs = gnn.predict_proba()
[21]:
# Label to inspect: 1, presumably 'arts' (names_labels[1]) — confirm the index convention
label = 1
# Predicted probability of that label for each node
scores = probs[:, label]
[22]:
# Visualization: nodes colored by their probability of the selected label
image = visualize_graph(
    adjacency,
    position=position,
    names=names,
    scores=scores,
)
SVG(image)
[22]:
../../_images/tutorials_gnn_gnn_classifier_27_0.svg

GraphSAGE

Another available GNN is GraphSAGE, selected with layer_types='Sage'.

[23]:
# GraphSAGE classifier with the same dimensions as the GCN above:
# one hidden layer of size hidden_dim, then one output unit per label
gnn = GNNClassifier(dims=[hidden_dim, n_labels], layer_types='Sage')
[24]:
# Layer-by-layer description of the GraphSAGE classifier
print(gnn)
GNNClassifier(
    Convolution(layer_type: sage, out_channels: 5, activation: ReLu, use_bias: True, normalization: left, self_embeddings: True, sample_size: 25)
    Convolution(layer_type: sage, out_channels: 3, activation: Cross entropy, use_bias: True, normalization: left, self_embeddings: True, sample_size: 25)
)
[25]:
# Training with the same train/test split as the GCN
labels_pred = gnn.fit_predict(
    adjacency,
    features,
    labels,
    n_epochs=200,
    random_state=42,
)
[26]:
# Accuracy on test set, i.e. on the nodes whose label was hidden (-1) during training.
# Deriving the mask from `labels` guarantees we never evaluate on training nodes.
test_mask = labels == -1
get_accuracy_score(labels_true[test_mask], labels_pred[test_mask])
[26]:
1.0
[27]:
# Learned parameters of each layer of the trained GNN
weights = [lyr.weight for lyr in gnn.layers]  # one weight matrix per layer
biases = [lyr.bias for lyr in gnn.layers]  # one bias per layer
[28]:
# Shape of each weight matrix: (input dimension, output dimension) of the layer
[weight.shape for weight in weights]
[28]:
[(11, 5), (5, 3)]
[29]:
# Shape of each bias: a single row with one entry per output dimension
[bias.shape for bias in biases]

[29]:
[(1, 5), (1, 3)]
[30]:
# Probability distribution over labels for each node (one column per label)
probs = gnn.predict_proba()
[31]:
# Label to inspect: 1, presumably 'arts' (names_labels[1]) — confirm the index convention
label = 1
# Predicted probability of that label for each node
scores = probs[:, label]
[32]:
# Visualization: nodes colored by their GraphSAGE probability of the selected label
image = visualize_graph(adjacency, position=position, names=names, scores=scores)
SVG(image)

[32]:
../../_images/tutorials_gnn_gnn_classifier_39_0.svg
[33]:
# Parameters of the GNN
# NOTE(review): this cell and the following ones repeat cells [27]-[32] verbatim on the
# same trained model; consider removing them or replacing them with a new example.
weights = [layer.weight for layer in gnn.layers]
biases = [layer.bias for layer in gnn.layers]
[34]:
# NOTE(review): duplicate of cell [28]
[weight.shape for weight in weights]
[34]:
[(11, 5), (5, 3)]
[35]:
# NOTE(review): duplicate of cell [29]
[bias.shape for bias in biases]

[35]:
[(1, 5), (1, 3)]
[36]:
# probability distribution over labels
# NOTE(review): duplicate of cell [30]
probs = gnn.predict_proba()
[37]:
# NOTE(review): duplicate of cell [31]
label = 1
scores = probs[:, label]
[38]:
# Visualization
# NOTE(review): duplicate of cell [32]; renders the same figure again
image = visualize_graph(adjacency, position=position, names=names, scores=scores)
SVG(image)

[38]:
../../_images/tutorials_gnn_gnn_classifier_45_0.svg