Graph Classification with Graph Neural Networks


Adapted from PyG Tutorial

In [6]:
# Install required packages.
!pip install -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.10.0+cu113.html
!pip install -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.10.0+cu113.html
!pip install -q git+https://github.com/rusty1s/pytorch_geometric.git

We will take a closer look at how to apply Graph Neural Networks (GNNs) to the task of graph classification. Graph classification refers to the problem of classifying entire graphs (as opposed to individual nodes) based on graph-level labels.

A common application of graph classification is molecular property prediction, in which molecules are represented as graphs and the task may be, for example, to infer whether a molecule inhibits HIV virus replication or not.

Dataset

The TUDataset collection provides a wide range of graph classification datasets. Let's load and inspect one of the smaller ones, the MUTAG dataset.

=== Description of the dataset ===

The MUTAG dataset consists of 188 chemical compounds divided into two classes according to their mutagenic effect on the bacterium Salmonella typhimurium.

The chemical data was obtained from http://cdb.ics.uci.edu and converted to graphs, where vertices represent atoms and edges represent chemical bonds. Explicit hydrogen atoms have been removed; vertices are labeled by atom type and edges by bond type (single, double, triple, or aromatic). The chemical data was processed using the Chemistry Development Kit (v1.4).

In [7]:
import torch
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='MUTAG')

print()
print(f'Dataset: {dataset}:')
print('====================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

data = dataset[0]  # Get the first graph object.

print()
print(data)
print('=============================================================')
Dataset: MUTAG(188):
====================
Number of graphs: 188
Number of features: 7
Number of classes: 2

Data(edge_index=[2, 38], x=[17, 7], edge_attr=[38, 4], y=[1])
=============================================================
In [8]:
# Gather some statistics about the first graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
Number of nodes: 17
Number of edges: 38
Average node degree: 2.24
Has isolated nodes: False
Has self-loops: False
Is undirected: True

This dataset provides 188 different graphs, and the task is to classify each graph into one of two classes.

The first graph object of the dataset comes with 17 nodes (with 7-dimensional feature vectors) and 38 edges (leading to an average node degree of 2.24).
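To put this first graph into context, it can help to look at a few dataset-level statistics as well. The following is a minimal sketch, assuming the `dataset` object loaded above:

# Sketch: aggregate simple statistics over all 188 graphs.
avg_nodes = sum(g.num_nodes for g in dataset) / len(dataset)
avg_edges = sum(g.num_edges for g in dataset) / len(dataset)
labels = torch.tensor([g.y.item() for g in dataset])

print(f'Average number of nodes per graph: {avg_nodes:.2f}')
print(f'Average number of edges per graph: {avg_edges:.2f}')
print(f'Class counts: {labels.bincount().tolist()}')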

In [17]:
from torch_geometric.utils import to_networkx
import networkx as nx
import matplotlib.pyplot as plt

def visualize_graph(G, color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")  # Use the passed-in node colors.
    plt.show()

# Color the nodes by atom type (the node features are one-hot atom-type encodings).
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.x.argmax(dim=1).tolist())
In [19]:
from torch_geometric.loader import DataLoader

torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:150]
test_dataset = dataset[150:]

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
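
Since the graphs in a mini-batch have different sizes, PyG batches them by stacking them into one large, disconnected graph and keeping a `batch` vector that assigns every node to its graph. A quick way to check this, using the `train_loader` defined above:

# Sketch: inspect a single mini-batch to see how PyG stacks graphs.
for step, batch_data in enumerate(train_loader):
    print(f'Step {step + 1}:')
    print(f'  Number of graphs in the current batch: {batch_data.num_graphs}')
    print(f'  Total number of nodes: {batch_data.num_nodes}')
    print(f'  Shape of the batch vector: {list(batch_data.batch.shape)}')
    break  # Inspecting one batch is enough here.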
In [21]:
from torch_geometric.nn import GraphConv
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import global_mean_pool

class GNN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GNN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GraphConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GraphConv(hidden_channels, hidden_channels)
        self.conv3 = GraphConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        x = global_mean_pool(x, batch)

        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GNN(hidden_channels=64)
print(model)
GNN(
  (conv1): GraphConv(7, 64)
  (conv2): GraphConv(64, 64)
  (conv3): GraphConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)
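The model follows the usual three-step recipe for graph classification: (1) embed each node via several rounds of message passing, (2) aggregate the node embeddings into a single graph embedding with `global_mean_pool`, and (3) classify the graph embedding with a final linear layer. As a quick sanity check (a sketch re-using `model` and `train_loader` from above), one forward pass should yield one score vector per graph:

# Sketch: the output should have shape [number of graphs in the batch, number of classes].
batch_data = next(iter(train_loader))
out = model(batch_data.x, batch_data.edge_index, batch_data.batch)
print(out.shape)  # e.g. torch.Size([64, 2])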
In [23]:
from IPython.display import Javascript
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
        optimizer.zero_grad()  # Clear gradients.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        loss = criterion(out, data.y)  # Compute the loss.
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.

@torch.no_grad()  # No gradients needed during evaluation.
def test(loader):
    model.eval()

    correct = 0
    for data in loader:  # Iterate in batches over the training/test dataset.
        out = model(data.x, data.edge_index, data.batch)
        pred = out.argmax(dim=1)  # Use the class with the highest score.
        correct += int((pred == data.y).sum())  # Check against ground-truth labels.
    return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')
Epoch: 001, Train Acc: 0.7333, Test Acc: 0.7895
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 007, Train Acc: 0.7333, Test Acc: 0.8158
Epoch: 008, Train Acc: 0.7267, Test Acc: 0.8158
Epoch: 009, Train Acc: 0.7867, Test Acc: 0.8421
Epoch: 010, Train Acc: 0.7733, Test Acc: 0.8158
Epoch: 011, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 012, Train Acc: 0.7933, Test Acc: 0.8421
Epoch: 013, Train Acc: 0.7733, Test Acc: 0.8421
Epoch: 014, Train Acc: 0.7733, Test Acc: 0.7895
Epoch: 015, Train Acc: 0.7933, Test Acc: 0.8421
Epoch: 016, Train Acc: 0.7667, Test Acc: 0.7632
Epoch: 017, Train Acc: 0.7933, Test Acc: 0.8421
Epoch: 018, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 019, Train Acc: 0.7867, Test Acc: 0.7895
Epoch: 020, Train Acc: 0.8133, Test Acc: 0.8421
Epoch: 021, Train Acc: 0.8000, Test Acc: 0.7632
Epoch: 022, Train Acc: 0.7933, Test Acc: 0.8421
Epoch: 023, Train Acc: 0.8133, Test Acc: 0.8421
Epoch: 024, Train Acc: 0.8667, Test Acc: 0.7368
Epoch: 025, Train Acc: 0.8467, Test Acc: 0.7632
Epoch: 026, Train Acc: 0.8400, Test Acc: 0.7368
Epoch: 027, Train Acc: 0.8400, Test Acc: 0.7632
Epoch: 028, Train Acc: 0.8133, Test Acc: 0.8421
Epoch: 029, Train Acc: 0.9067, Test Acc: 0.7632
Epoch: 030, Train Acc: 0.8800, Test Acc: 0.7895
Epoch: 031, Train Acc: 0.8600, Test Acc: 0.7632
Epoch: 032, Train Acc: 0.9133, Test Acc: 0.7895
Epoch: 033, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 034, Train Acc: 0.8933, Test Acc: 0.8158
Epoch: 035, Train Acc: 0.9200, Test Acc: 0.8684
Epoch: 036, Train Acc: 0.9000, Test Acc: 0.7368
Epoch: 037, Train Acc: 0.9267, Test Acc: 0.8684
Epoch: 038, Train Acc: 0.9400, Test Acc: 0.8684
Epoch: 039, Train Acc: 0.9133, Test Acc: 0.7632
Epoch: 040, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 041, Train Acc: 0.9133, Test Acc: 0.7368
Epoch: 042, Train Acc: 0.9067, Test Acc: 0.8421
Epoch: 043, Train Acc: 0.9133, Test Acc: 0.8421
Epoch: 044, Train Acc: 0.9267, Test Acc: 0.8684
Epoch: 045, Train Acc: 0.9067, Test Acc: 0.8158
Epoch: 046, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 047, Train Acc: 0.9200, Test Acc: 0.8158
Epoch: 048, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 049, Train Acc: 0.9333, Test Acc: 0.8684
Epoch: 050, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 051, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 052, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 053, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 054, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 055, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 056, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 057, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 058, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 059, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 060, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 061, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 062, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 063, Train Acc: 0.9267, Test Acc: 0.8684
Epoch: 064, Train Acc: 0.9333, Test Acc: 0.8421
Epoch: 065, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 066, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 067, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 068, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 069, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 070, Train Acc: 0.9133, Test Acc: 0.8158
Epoch: 071, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 072, Train Acc: 0.9400, Test Acc: 0.8421
Epoch: 073, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 074, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 075, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 076, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 077, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 078, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 079, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 080, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 081, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 082, Train Acc: 0.9400, Test Acc: 0.8947
Epoch: 083, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 084, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 085, Train Acc: 0.9333, Test Acc: 0.8947
Epoch: 086, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 087, Train Acc: 0.9200, Test Acc: 0.8158
Epoch: 088, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 089, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 090, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 091, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 092, Train Acc: 0.9267, Test Acc: 0.8421
Epoch: 093, Train Acc: 0.9267, Test Acc: 0.8158
Epoch: 094, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 095, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 096, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 097, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 098, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 099, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 100, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 101, Train Acc: 0.9400, Test Acc: 0.8684
Epoch: 102, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 103, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 104, Train Acc: 0.9267, Test Acc: 0.8684
Epoch: 105, Train Acc: 0.9267, Test Acc: 0.7895
Epoch: 106, Train Acc: 0.8333, Test Acc: 0.7632
Epoch: 107, Train Acc: 0.9267, Test Acc: 0.8684
Epoch: 108, Train Acc: 0.8800, Test Acc: 0.7895
Epoch: 109, Train Acc: 0.8867, Test Acc: 0.8158
Epoch: 110, Train Acc: 0.9133, Test Acc: 0.8684
Epoch: 111, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 112, Train Acc: 0.9267, Test Acc: 0.7895
Epoch: 113, Train Acc: 0.9133, Test Acc: 0.8421
Epoch: 114, Train Acc: 0.9400, Test Acc: 0.8421
Epoch: 115, Train Acc: 0.9333, Test Acc: 0.7632
Epoch: 116, Train Acc: 0.9267, Test Acc: 0.7895
Epoch: 117, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 118, Train Acc: 0.9200, Test Acc: 0.8158
Epoch: 119, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 120, Train Acc: 0.9333, Test Acc: 0.8684
Epoch: 121, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 122, Train Acc: 0.9467, Test Acc: 0.8421
Epoch: 123, Train Acc: 0.9400, Test Acc: 0.8421
Epoch: 124, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 125, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 126, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 127, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 128, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 129, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 130, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 131, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 132, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 133, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 134, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 135, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 136, Train Acc: 0.9333, Test Acc: 0.7895
Epoch: 137, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 138, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 139, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 140, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 141, Train Acc: 0.9400, Test Acc: 0.8421
Epoch: 142, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 143, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 144, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 145, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 146, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 147, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 148, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 149, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 150, Train Acc: 0.9333, Test Acc: 0.8158
Epoch: 151, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 152, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 153, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 154, Train Acc: 0.9200, Test Acc: 0.7632
Epoch: 155, Train Acc: 0.9267, Test Acc: 0.7895
Epoch: 156, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 157, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 158, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 159, Train Acc: 0.9333, Test Acc: 0.7895
Epoch: 160, Train Acc: 0.9467, Test Acc: 0.7895
Epoch: 161, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 162, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 163, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 164, Train Acc: 0.9400, Test Acc: 0.7895
Epoch: 165, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 166, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 167, Train Acc: 0.9467, Test Acc: 0.8158
Epoch: 168, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 169, Train Acc: 0.9400, Test Acc: 0.8158
Epoch: 170, Train Acc: 0.9467, Test Acc: 0.8158

As one can see, our model reaches around 82% test accuracy. The fluctuations in accuracy can be explained by the rather small test set (only 38 graphs) and usually disappear once GNNs are applied to larger datasets.
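
Because the test split is so small, a more robust estimate can be obtained by repeating the random split with several seeds and averaging the resulting test accuracies. A rough sketch, re-using the `GNN` class and the `train()`/`test()` helpers defined above (it re-binds the global `model`, `optimizer`, and loader variables those helpers rely on):

# Sketch: average the final test accuracy over several random splits.
accs = []
for seed in range(5):
    torch.manual_seed(seed)
    shuffled = dataset.shuffle()
    train_dataset, test_dataset = shuffled[:150], shuffled[150:]
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    model = GNN(hidden_channels=64)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(1, 171):
        train()
    accs.append(test(test_loader))

print(f'Mean test accuracy over 5 random splits: {sum(accs) / len(accs):.4f}')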