[1]:
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx

from topomodelx.nn.hypergraph.unigin import UniGIN
from topomodelx.utils.sparse import from_sparse

# Fix torch's global RNG so weight initialisation and dropout are reproducible.
SEED = 0
torch.manual_seed(SEED)

If GPUs are available, we will use them; otherwise, the notebook runs on the CPU.

[2]:
# Use CUDA when torch can see a GPU; otherwise fall back to the CPU.
has_gpu = torch.cuda.is_available()
device = torch.device("cuda") if has_gpu else torch.device("cpu")
print(device)
cuda

Train a UniGIN TNN

Pre-processing

Import data

The first step is to import the dataset, MUTAG, a benchmark dataset for graph classification. To keep the example fast, only the first 100 graphs are used. We then lift each graph into our domain of choice, a hypergraph.

We will also retrieve:

- the input signal on the nodes of each hypergraph, since that is what we feed to the model as input;
- the binary label associated with each hypergraph.

[3]:
# Only the first N_GRAPHS molecules are used to keep the tutorial fast.
N_GRAPHS = 100

dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True)
dataset = dataset[:N_GRAPHS]

hg_list = []  # lifted hypergraphs, one per molecule
x_1_list = []  # node-feature matrices (passed to the model as its x_0 input)
y_list = []  # graph-level labels
incidence_1_list = []  # incidence matrices, converted to torch with from_sparse
for graph in dataset:
    # Lift: graph -> simplicial complex -> hypergraph.
    hg = tnx.SimplicialComplex(to_networkx(graph)).to_hypergraph()
    incidence_1 = from_sparse(hg.incidence_matrix())

    # Sanity check: the incidence rows and the node-feature rows presumably both
    # index the graph's nodes, so their counts must agree. Failing here gives a
    # clearer message than a shape error deep inside the model.
    # NOTE(review): this checks counts only, not that the node orderings match.
    assert incidence_1.shape[0] == graph.x.shape[0], (
        f"incidence matrix has {incidence_1.shape[0]} rows but the graph has "
        f"{graph.x.shape[0]} node-feature rows"
    )

    hg_list.append(hg)
    x_1_list.append(graph.x.to(device))
    y_list.append(graph.y.to(device))
    incidence_1_list.append(incidence_1.to(device))

Create the Neural Network

Define the network that initializes the base model and sets up the readout operation. Different downstream tasks might require different pooling procedures.

[4]:
class Network(torch.nn.Module):
    """Network class that initializes the base model and readout layer.

    Base model parameters:
    ----------
    Required:
    in_channels : int
        Dimension of the input features.
    hidden_channels : int
        Dimension of the hidden features.

    Optional:
    **kwargs : dict
        Additional arguments for the base model (forwarded to UniGIN).

    Readout layer parameters:
    ----------
    out_channels : int
        Dimension of the output features.
    task_level : str
        Level of the task. Either "graph" or "node".

    Raises
    ------
    ValueError
        If ``task_level`` is neither "graph" nor "node".
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, task_level="graph", **kwargs
    ):
        super().__init__()

        # Reject typos early; previously any other string silently behaved
        # like "node" (no pooling).
        if task_level not in ("graph", "node"):
            raise ValueError(
                f'task_level must be "graph" or "node", got {task_level!r}'
            )

        # Define the model
        self.base_model = UniGIN(
            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs
        )

        # Readout
        self.linear = torch.nn.Linear(hidden_channels, out_channels)
        # Graph-level tasks pool node embeddings into a single vector before the
        # linear readout; node-level tasks apply the readout to every node.
        self.out_pool = task_level == "graph"

    def forward(self, x_0, incidence_1):
        """Run the base model on one hypergraph and apply the readout.

        Parameters
        ----------
        x_0 : torch.Tensor
            Node features; assumed shape (n_nodes, in_channels) — TODO confirm.
        incidence_1 : torch.Tensor
            Incidence matrix, passed unchanged to the base model.

        Returns
        -------
        torch.Tensor
            Output of the linear readout. For graph-level tasks (and 2-D node
            embeddings) this is a single vector of size ``out_channels``;
            for node-level tasks it is one row per node.
        """
        # Base model; its second output is not used by the readout.
        x_0, _ = self.base_model(x_0, incidence_1)

        # Pool over all nodes in the hypergraph (element-wise max across dim 0).
        x = x_0.max(dim=0).values if self.out_pool else x_0

        return self.linear(x)

Initialize the model

[5]:
# Base model hyperparameters
# (n_layers, input_drop and layer_drop reach UniGIN through Network's **kwargs)
in_channels = x_1_list[0].shape[1]  # node-feature dimension of the dataset
hidden_channels = 32
n_layers = 3
input_drop = 0.2
layer_drop = 0.2

# Readout hyperparameters
out_channels = 2  # one logit per class of the binary label
task_level = "graph"


model = Network(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    input_drop=input_drop,
    layer_drop=layer_drop,
    n_layers=n_layers,
    out_channels=out_channels,
    task_level=task_level,
).to(device)

Train the Neural Network

We specify the loss function and the optimizer for the model defined above.

[6]:
# Classification loss over the model's logits, optimised with Adam.
learning_rate = 0.01
crit = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

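Split the dataset into training, validation, and test sets: 20% of the graphs are held out for testing, and 20% of the remainder for validation. The splits are taken in dataset order (no shuffling).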

[7]:
# Split features, incidence matrices and labels in a single call so they are
# indexed identically by construction (and their lengths are checked), instead
# of relying on several independent unshuffled splits happening to line up.
(
    x_1_train,
    x_1_test,
    incidence_1_train,
    incidence_1_test,
    y_train,
    y_test,
) = train_test_split(
    x_1_list, incidence_1_list, y_list, test_size=0.2, shuffle=False
)

# Carve a validation set out of the training portion, again keeping all three
# sequences aligned.
(
    x_1_train,
    x_1_val,
    incidence_1_train,
    incidence_1_val,
    y_train,
    y_val,
) = train_test_split(
    x_1_train, incidence_1_train, y_train, test_size=0.2, shuffle=False
)

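Note: the number of epochs below has been kept low to facilitate debugging and testing; real use cases will likely require more epochs. The recorded output shown below logs epochs 0 to 60, so it does not correspond to the settings in this cell. Re-running it will print fewer lines and different numbers.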

[8]:
test_interval = 10
num_epochs = 10


def evaluate(net, xs, incidences, ys):
    """Return the classification accuracy of ``net`` on one data split.

    Each graph is scored independently and the predicted class is the argmax
    of the network output. The three sequences must have equal length
    (``zip(strict=True)`` raises ``ValueError`` otherwise). An empty split
    yields ``nan`` rather than raising ``ZeroDivisionError``.
    """
    net.eval()
    correct = 0
    with torch.no_grad():
        for x, incidence, y in zip(xs, incidences, ys, strict=True):
            pred = torch.argmax(net(x, incidence))
            if pred == y:
                correct += 1
    return correct / len(ys) if len(ys) > 0 else float("nan")


for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    # Full-batch training: per-graph losses are summed and a single optimizer
    # step is taken per epoch.
    loss = 0
    for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train, strict=True):
        output = model(x_1, incidence_1)
        # unsqueeze adds the batch dimension CrossEntropyLoss expects.
        loss += crit(output.unsqueeze(0), y)
    loss.backward()
    optimizer.step()
    # Also validate after the last epoch, so the final model is always reported
    # even when num_epochs is not a multiple of test_interval.
    if epoch % test_interval == 0 or epoch == num_epochs - 1:
        print(f"Epoch {epoch} loss: {loss.item()}")
        val_acc = evaluate(model, x_1_val, incidence_1_val, y_val)
        print(f"Epoch {epoch} Validation accuracy: {val_acc}")

test_acc = evaluate(model, x_1_test, incidence_1_test, y_test)
print(f"Test accuracy: {test_acc}")
Epoch 0 loss: 1261.591552734375
Epoch 0 Validation accuracy: 0.5625
Epoch 10 loss: 50.00477600097656
Epoch 10 Validation accuracy: 0.5625
Epoch 20 loss: 37.117366790771484
Epoch 20 Validation accuracy: 0.4375
Epoch 30 loss: 30.90342903137207
Epoch 30 Validation accuracy: 0.3125
Epoch 40 loss: 25.588314056396484
Epoch 40 Validation accuracy: 0.5625
Epoch 50 loss: 26.57889747619629
Epoch 50 Validation accuracy: 0.625
Epoch 60 loss: 19.412574768066406
Epoch 60 Validation accuracy: 0.5625
Test accuracy: 0.7