[1]:
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split
from torch_geometric.datasets import TUDataset
from torch_geometric.utils.convert import to_networkx
from topomodelx.nn.hypergraph.unigin import UniGIN
from topomodelx.utils.sparse import from_sparse
torch.manual_seed(seed=0)  # fix torch's RNG so weight init and dropout are reproducible
If GPUs are available, we will use them; otherwise, the notebook runs on the CPU.
[2]:
# Prefer a CUDA device when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)
cuda
Train a UniGIN TNN#
Pre-processing#
Import data#
The first step is to import the dataset, MUTAG, a benchmark dataset for graph classification. We then lift each graph into our domain of choice, a hypergraph.
We will also retrieve:
- the input signal on the nodes of each hypergraph, which is what we feed to the model as input;
- the binary label associated with each hypergraph.
[3]:
# Load MUTAG and keep only its first 100 graphs.
dataset = TUDataset(root="/tmp/MUTAG", name="MUTAG", use_edge_attr=True)[:100]

# Single pass over the graphs: lift each one to a hypergraph (graph -> simplicial
# complex -> hypergraph) and collect its node features, label, and incidence matrix.
hg_list, x_1_list, y_list, incidence_1_list = [], [], [], []
for graph in dataset:
    hg = tnx.SimplicialComplex(to_networkx(graph)).to_hypergraph()
    incidence_1 = from_sparse(hg.incidence_matrix())
    hg_list.append(hg)
    x_1_list.append(graph.x.to(device))
    y_list.append(graph.y.to(device))
    incidence_1_list.append(incidence_1.to(device))
Create the Neural Network#
Define the network that initializes the base model and sets up the readout operation. Different downstream tasks might require different pooling procedures.
[4]:
class Network(torch.nn.Module):
    """UniGIN base model followed by a linear readout layer.

    Parameters
    ----------
    in_channels : int
        Dimension of the input node features.
    hidden_channels : int
        Dimension of the hidden features produced by the base model.
    out_channels : int
        Dimension of the output of the readout layer (e.g. number of classes).
    task_level : str, default="graph"
        Level of the task. Either "graph" (node features are max-pooled into a
        single vector before the readout) or "node" (no pooling).
    **kwargs : dict
        Additional keyword arguments forwarded to ``UniGIN``.

    Raises
    ------
    ValueError
        If ``task_level`` is neither "graph" nor "node".
    """

    def __init__(
        self, in_channels, hidden_channels, out_channels, task_level="graph", **kwargs
    ):
        super().__init__()
        # Reject unknown task levels instead of silently treating them as "node".
        if task_level not in ("graph", "node"):
            raise ValueError(
                f'task_level must be "graph" or "node", got {task_level!r}'
            )
        # Base model
        self.base_model = UniGIN(
            in_channels=in_channels, hidden_channels=hidden_channels, **kwargs
        )
        # Readout
        self.linear = torch.nn.Linear(hidden_channels, out_channels)
        self.out_pool = task_level == "graph"

    def forward(self, x_0, incidence_1):
        """Compute the model output for a single hypergraph.

        Parameters
        ----------
        x_0 : torch.Tensor
            Node features; presumably of shape (n_nodes, in_channels) — TODO confirm.
        incidence_1 : torch.Tensor
            Incidence matrix passed unchanged to the base model.

        Returns
        -------
        torch.Tensor
            Readout of the max-pooled node features when ``task_level="graph"``,
            otherwise the readout applied to every node's features.
        """
        # The base model returns a pair; only the first element (node features) is used.
        x_0, _ = self.base_model(x_0, incidence_1)
        # Graph-level task: max-pool over all nodes (dim 0) of the hypergraph.
        x = torch.max(x_0, dim=0).values if self.out_pool else x_0
        return self.linear(x)
Initialize the model
[5]:
# Base model hyperparameters
in_channels = x_1_list[0].shape[1]  # number of input features per node
hidden_channels = 32
n_layers = 3
input_drop = 0.2
layer_drop = 0.2
# NOTE: the former `mlp_num_layers` setting was never passed to the model and
# had no effect, so it has been removed to avoid suggesting it is configurable.

# Readout hyperparameters
out_channels = 2  # binary graph labels
task_level = "graph"

model = Network(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    input_drop=input_drop,
    layer_drop=layer_drop,
    n_layers=n_layers,
    out_channels=out_channels,
    task_level=task_level,
).to(device)
Train the Neural Network#
Next, we specify the optimizer and the loss function for the model defined above.
[6]:
# Cross-entropy loss for classification; Adam over all model parameters.
crit = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
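Split the dataset into training, validation, and test sets. 20% of the graphs are held out for testing, and 20% of the remainder is used for validation. The splits are taken in order, without shuffling.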
[7]:
# Split features, incidence matrices, and labels in a single call per stage so
# that the three lists are indexed identically (and checked for equal length),
# even if shuffling is enabled later.
# Stage 1: hold out the last 20% of graphs as the test set.
(
    x_1_train,
    x_1_test,
    incidence_1_train,
    incidence_1_test,
    y_train,
    y_test,
) = train_test_split(
    x_1_list, incidence_1_list, y_list, test_size=0.2, shuffle=False
)
# Stage 2: hold out the last 20% of the remaining graphs as the validation set.
(
    x_1_train,
    x_1_val,
    incidence_1_train,
    incidence_1_val,
    y_train,
    y_val,
) = train_test_split(
    x_1_train, incidence_1_train, y_train, test_size=0.2, shuffle=False
)
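Note: The number of epochs below has been kept low to facilitate debugging and testing; real use cases will likely require more epochs. The printed output below comes from an earlier, longer run (epochs 0–60) and will differ when the cell is re-executed with these settings.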
[8]:
test_interval = 10
num_epochs = 10


def evaluate(model, xs, incidences, ys):
    """Return the fraction of samples whose argmax prediction equals the label.

    Switches ``model`` to eval mode and runs without gradient tracking.
    ``xs``, ``incidences`` and ``ys`` must have the same length (enforced by zip).
    """
    model.eval()
    correct = 0
    with torch.no_grad():
        for x, incidence, y in zip(xs, incidences, ys, strict=True):
            pred = torch.argmax(model(x, incidence))
            if pred == y:
                correct += 1
    return correct / len(ys)


for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    # Full-batch training: per-graph losses are summed and one step is taken per epoch.
    loss = 0
    for x_1, incidence_1, y in zip(x_1_train, incidence_1_train, y_train, strict=True):
        output = model(x_1, incidence_1)
        # Add a batch dimension so the output pairs with a single-label target.
        loss += crit(output.unsqueeze(0), y)
    loss.backward()
    optimizer.step()
    # Also evaluate after the last epoch so the trained model is always validated.
    if epoch % test_interval == 0 or epoch == num_epochs - 1:
        print(f"Epoch {epoch} loss: {loss.item()}")
        val_accuracy = evaluate(model, x_1_val, incidence_1_val, y_val)
        print(f"Epoch {epoch} Validation accuracy: {val_accuracy}")

test_accuracy = evaluate(model, x_1_test, incidence_1_test, y_test)
print(f"Test accuracy: {test_accuracy}")
Epoch 0 loss: 1261.591552734375
Epoch 0 Validation accuracy: 0.5625
Epoch 10 loss: 50.00477600097656
Epoch 10 Validation accuracy: 0.5625
Epoch 20 loss: 37.117366790771484
Epoch 20 Validation accuracy: 0.4375
Epoch 30 loss: 30.90342903137207
Epoch 30 Validation accuracy: 0.3125
Epoch 40 loss: 25.588314056396484
Epoch 40 Validation accuracy: 0.5625
Epoch 50 loss: 26.57889747619629
Epoch 50 Validation accuracy: 0.625
Epoch 60 loss: 19.412574768066406
Epoch 60 Validation accuracy: 0.5625
Test accuracy: 0.7