Train a Combinatorial Complex Attention Neural Network for Mesh Classification.#

We create and train a higher-order attentional neural network for mesh classification that operates over combinatorial complexes. The model was introduced in Figure 35(b) of Hajij et al., Topological Deep Learning: Going Beyond Graph Data (2023).

The Neural Network:#

The neural network is composed of a sequence of identical attention layers operating on a dimension-two combinatorial complex, a fully connected layer that embeds the features of each dimension into a common space, and a final transformation into a vector of per-class scores. Each attention layer is composed of two levels. In both levels, the messages computed for cells of the same dimension are aggregated using a sum operation. All messages are computed using the attention mechanisms for squared and non-squared neighborhoods presented in Definitions 31, 32, and 33 of Hajij et al., Topological Deep Learning: Going Beyond Graph Data (2023). Each level of every layer follows the message passing scheme below:

  1. First level:

🟥 \(\quad m^{0\rightarrow 0}_{y\rightarrow x} = \left((A_{\uparrow, 0})_{xy} \cdot \text{att}_{xy}^{0\rightarrow 0}\right) h_y^{t,(0)} \Theta^t_{0\rightarrow 0}\)

🟥 \(\quad m^{0\rightarrow 1}_{y\rightarrow x} = \left((B_{1}^T)_{xy} \cdot \text{att}_{xy}^{0\rightarrow 1}\right) h_y^{t,(0)} \Theta^t_{0\rightarrow 1}\)

🟥 \(\quad m^{1\rightarrow 0}_{y\rightarrow x} = \left((B_{1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 0}\right) h_y^{t,(1)} \Theta^t_{1\rightarrow 0}\)

🟥 \(\quad m^{1\rightarrow 2}_{y\rightarrow x} = \left((B_{2}^T)_{xy} \cdot \text{att}_{xy}^{1\rightarrow 2}\right) h_y^{t,(1)} \Theta^t_{1\rightarrow 2}\)

🟥 \(\quad m^{2\rightarrow 1}_{y\rightarrow x} = \left((B_{2})_{xy} \cdot \text{att}_{xy}^{2\rightarrow 1}\right) h_y^{t,(2)} \Theta^t_{2\rightarrow 1}\)

🟧 \(\quad m^{0\rightarrow 0}_{x}=\phi_u\left(\sum_{y\in A_{\uparrow, 0}(x)} m^{0\rightarrow 0}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{0\rightarrow 1}_{x}=\phi_u\left(\sum_{y\in B_{1}^T(x)} m^{0\rightarrow 1}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{1\rightarrow 0}_{x}=\phi_u\left(\sum_{y\in B_{1}(x)} m^{1\rightarrow 0}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{1\rightarrow 2}_{x}=\phi_u\left(\sum_{y\in B_{2}^T(x)} m^{1\rightarrow 2}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{2\rightarrow 1}_{x}=\phi_u\left(\sum_{y\in B_{2}(x)} m^{2\rightarrow 1}_{y\rightarrow x}\right)\)

🟩 \(\quad m_x^{(0)}=\phi_a\left(m^{0\rightarrow 0}_{x}+m^{1\rightarrow 0}_{x}\right)\)

🟩 \(\quad m_x^{(1)}=\phi_a\left(m^{0\rightarrow 1}_{x}+m^{2\rightarrow 1}_{x}\right)\)

🟩 \(\quad m_x^{(2)}=\phi_a\left(m^{1\rightarrow 2}_{x}\right)\)

🟦 \(\quad i_x^{t,(0)} = m_x^{(0)}\)

🟦 \(\quad i_x^{t,(1)} = m_x^{(1)}\)

🟦 \(\quad i_x^{t,(2)} = m_x^{(2)}\)

where \(i_x^{t,(\cdot)}\) represents intermediate feature vectors.

  2. Second level:

🟥 \(\quad m^{0\rightarrow 0}_{y\rightarrow x} = \left((A_{\uparrow, 0})_{xy} \cdot \text{att}_{xy}^{0\rightarrow 0}\right) i_y^{t,(0)} \Theta^t_{0\rightarrow 0}\)

🟥 \(\quad m^{1\rightarrow 1}_{y\rightarrow x} = \left((A_{\uparrow, 1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 1}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 1}\)

🟥 \(\quad m^{2\rightarrow 2}_{y\rightarrow x} = \left((A_{\downarrow, 2})_{xy} \cdot \text{att}_{xy}^{2\rightarrow 2}\right) i_y^{t,(2)} \Theta^t_{2\rightarrow 2}\)

🟥 \(\quad m^{0\rightarrow 1}_{y\rightarrow x} = \left((B_{1}^T)_{xy} \cdot \text{att}_{xy}^{0\rightarrow 1}\right) i_y^{t,(0)} \Theta^t_{0\rightarrow 1}\)

🟥 \(\quad m^{1\rightarrow 0}_{y\rightarrow x} = \left((B_{1})_{xy} \cdot \text{att}_{xy}^{1\rightarrow 0}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 0}\)

🟥 \(\quad m^{1\rightarrow 2}_{y\rightarrow x} = \left((B_{2}^T)_{xy} \cdot \text{att}_{xy}^{1\rightarrow 2}\right) i_y^{t,(1)} \Theta^t_{1\rightarrow 2}\)

🟧 \(\quad m^{0\rightarrow 0}_{x}=\phi_u\left(\sum_{y\in A_{\uparrow, 0}(x)} m^{0\rightarrow 0}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{1\rightarrow 1}_{x}=\phi_u\left(\sum_{y\in A_{\uparrow, 1}(x)} m^{1\rightarrow 1}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{2\rightarrow 2}_{x}=\phi_u\left(\sum_{y\in A_{\downarrow, 2}(x)} m^{2\rightarrow 2}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{0\rightarrow 1}_{x}=\phi_u\left(\sum_{y\in B_{1}^T(x)} m^{0\rightarrow 1}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{1\rightarrow 0}_{x}=\phi_u\left(\sum_{y\in B_{1}(x)} m^{1\rightarrow 0}_{y\rightarrow x}\right)\)

🟧 \(\quad m^{1\rightarrow 2}_{x}=\phi_u\left(\sum_{y\in B_{2}^T(x)} m^{1\rightarrow 2}_{y\rightarrow x}\right)\)

🟩 \(\quad m_x^{(0)}=\phi_a\left(m^{0\rightarrow 0}_{x}+m^{1\rightarrow 0}_{x}\right)\)

🟩 \(\quad m_x^{(1)}=\phi_a\left(m^{1\rightarrow 1}_{x} + m^{0\rightarrow 1}_{x}\right)\)

🟩 \(\quad m_x^{(2)}=\phi_a\left(m^{1\rightarrow 2}_{x} + m^{2\rightarrow 2}_{x}\right)\)

🟦 \(\quad h_x^{t+1,(0)} = m_x^{(0)}\)

🟦 \(\quad h_x^{t+1,(1)} = m_x^{(1)}\)

🟦 \(\quad h_x^{t+1,(2)} = m_x^{(2)}\)

In both message passing levels, \(\phi_u\) and \(\phi_a\) represent activation functions used for the within-neighborhood and between-neighborhood aggregations, respectively. \(\Theta\) and \(\text{att}\) represent learnable weight and attention matrices, respectively, which are different in each level. Attention matrices are introduced in Figure 35(b) of Hajij et al., Topological Deep Learning: Going Beyond Graph Data (2023). In this implementation, attention coefficients are computed using the LeakyReLU activation function, as in earlier versions of the paper. We give more information about the actual implementation of the neural network in the following sections of this notebook.
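To make the red and orange steps concrete, the snippet below sketches a single squared-neighborhood attention message (for example, rank 0 to rank 0) in plain PyTorch with dense tensors. It is an illustration only: the function name `hbs_message` and its signature are our own and do not correspond to the TopoModelX `HBS` block, and the within-neighborhood activation \(\phi_u\) is omitted.

import torch


def hbs_message(h, a, theta, att_weight, negative_slope=0.2):
    """Illustrative squared-neighborhood attention message (not the TopoModelX API).

    h : (n, d_in) cell features, a : (n, n) dense neighborhood matrix,
    theta : (d_in, d_out) weight matrix, att_weight : (2 * d_out,) attention vector.
    """
    d_out = theta.shape[1]
    hw = h @ theta  # project the features: h_y Theta
    # Attention logits e_xy = LeakyReLU([hw_x || hw_y] . att_weight)
    e = torch.nn.functional.leaky_relu(
        (hw @ att_weight[:d_out]).unsqueeze(1) + (hw @ att_weight[d_out:]).unsqueeze(0),
        negative_slope,
    )
    # Keep only neighbors (a_xy != 0) and normalize row-wise with a softmax
    att = torch.softmax(e.masked_fill(a == 0, float("-inf")), dim=1)
    att = torch.nan_to_num(att)  # cells without neighbors get all-zero attention
    # Red and orange steps (without phi_u): m_x = sum_y (a_xy * att_xy) h_y Theta
    return (a * att) @ hw


h = torch.randn(5, 6)                 # 5 cells with 6 features each
a = (torch.rand(5, 5) > 0.5).float()  # toy dense neighborhood matrix
print(hbs_message(h, a, torch.randn(6, 60), torch.randn(120)).shape)  # torch.Size([5, 60])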

Notation and the adjacency, coadjacency, and incidence matrices are defined in Papillon et al., Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023). The tensor diagram for the layer can be found in the first column and last row of Figure 11 of the same paper.

The Task:#

We train this model to perform entire mesh classification on the `SHREC 2016 dataset from ShapeNet <http://shapenet.cs.stanford.edu/shrec16/>`__. This dataset contains 480 training and 120 testing 3D mesh samples, belonging to 30 distinct classes and represented as simplicial complexes.

Each mesh contains a set of vertices, edges, and faces. Each of these entities has a set of features associated with it:

  • Node features \(v \in \mathbb{R}^6\) defined as the direct sum of the following features:

    • Position \(p_v \in \mathbb{R}^3\) coordinates.

    • Normal \(n_v \in \mathbb{R}^3\) coordinates.

  • Edge features \(e \in \mathbb{R}^{10}\) defined as the direct sum of the following features:

    • Dihedral angle \(\phi \in \mathbb{R}\).

    • Edge span \(l \in \mathbb{R}\).

    • 2 edge angles \(\theta_e \in \mathbb{R}^2\) in the triangles incident to the edge.

    • 6 edge ratios \(r \in \mathbb{R}^6\).

  • Face features \(f \in \mathbb{R}^7\) defined as the direct sum of the following features:

    • Face area \(a \in \mathbb{R}\).

    • Face normal \(n_f \in \mathbb{R}^3\).

    • 3 face angles \(\theta_f \in \mathbb{R}^3\).

We lift the simplicial complexes representing each mesh to a topologically equivalent combinatorial complex representation.
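For illustration only (the two triangles below are a made-up toy example, not a SHREC 2016 mesh, and we assume the `tnx.SimplicialComplex` constructor accepts a list of maximal simplices), the lift can be performed directly with TopoNetX:

sc = tnx.SimplicialComplex([[0, 1, 2], [1, 2, 3]])  # two triangles sharing an edge
ccc = sc.to_combinatorial_complex()  # same cells, now as a combinatorial complex
print(ccc)  # should describe 4 rank-0 cells, 5 rank-1 cells and 2 rank-2 cells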

The task is to predict the class to which a given mesh belongs, given its combinatorial complex representation. For this purpose we implement the Higher Order Attention Model for Mesh Classification first introduced in Hajij et al., Topological Deep Learning: Going Beyond Graph Data (2023).

Set-up#

[1]:
import numpy as np
import toponetx as tnx
import torch
from torch.utils.data import DataLoader, Dataset

from topomodelx.nn.combinatorial.hmc import HMC

If GPUs are available, we will make use of them. Otherwise, this will run on CPU.

[2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
cpu

Pre-processing#

Import data#

We first create a class for the SHREC 2016 dataset. This class will be used to load the data and create the necessary neighborhood matrices for each combinatorial complex in the dataset.

[3]:
class SHRECDataset(Dataset):
    """Class for the SHREC 2016 dataset.

    Parameters
    ----------
    data : npz file
        npz file containing the SHREC 2016 data.
    """

    def __init__(self, data) -> None:
        self.complexes = [cc.to_combinatorial_complex() for cc in data["complexes"]]
        self.x_0 = data["node_feat"]
        self.x_1 = data["edge_feat"]
        self.x_2 = data["face_feat"]
        self.y = data["label"]
        self.a0, self.a1, self.coa2, self.b1, self.b2 = self._get_neighborhood_matrix()

    def _get_neighborhood_matrix(self) -> tuple[list[torch.Tensor], ...]:
        """Neighborhood matrices for each combinatorial complex in the dataset.

        Following the Higher Order Attention Model for Mesh Classification message passing scheme, this method computes the necessary neighborhood matrices
        for each combinatorial complex in the dataset. This method computes:

        - Adjacency matrices for each 0-cell in the dataset.
        - Adjacency matrices for each 1-cell in the dataset.
        - Coadjacency matrices for each 2-cell in the dataset.
        - Incidence matrices from 1-cells to 0-cells for each 1-cell in the dataset.
        - Incidence matrices from 2-cells to 1-cells for each 2-cell in the dataset.

        Returns
        -------
        a0 : list of torch.sparse.FloatTensor
            Adjacency matrices for each 0-cell in the dataset.
        a1 : list of torch.sparse.FloatTensor
            Adjacency matrices for each 1-cell in the dataset.
        coa2 : list of torch.sparse.FloatTensor
            Coadjacency matrices for each 2-cell in the dataset.
        b1 : list of torch.sparse.FloatTensor
            Incidence matrices from 1-cells to 0-cells for each 1-cell in the dataset.
        b2 : list of torch.sparse.FloatTensor
            Incidence matrices from 2-cells to 1-cells for each 2-cell in the dataset.
        """

        a0 = []
        a1 = []
        coa2 = []
        b1 = []
        b2 = []

        for cc in self.complexes:
            a0.append(torch.from_numpy(cc.adjacency_matrix(0, 1).todense()).to_sparse())
            a1.append(torch.from_numpy(cc.adjacency_matrix(1, 2).todense()).to_sparse())

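            # Coadjacency of the 2-cells: two faces are coadjacent when they share an
            # edge, computed as B_2^T @ B_2 with the diagonal zeroed out.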
            B = cc.incidence_matrix(rank=1, to_rank=2)
            A = B.T @ B
            A.setdiag(0)
            coa2.append(torch.from_numpy(A.todense()).to_sparse())

            b1.append(torch.from_numpy(cc.incidence_matrix(0, 1).todense()).to_sparse())
            b2.append(torch.from_numpy(cc.incidence_matrix(1, 2).todense()).to_sparse())

        return a0, a1, coa2, b1, b2

    def num_classes(self) -> int:
        """Returns the number of classes in the dataset.

        Returns
        -------
        int
            Number of classes in the dataset.
        """
        return len(np.unique(self.y))

    def channels_dim(self) -> tuple[int, int, int]:
        """Returns the number of channels for each input signal.

        Returns
        -------
        tuple of int
            Number of channels for each input signal.
        """
        return (self.x_0[0].shape[1], self.x_1[0].shape[1], self.x_2[0].shape[1])

    def __len__(self) -> int:
        """Returns the number of elements in the dataset.

        Returns
        -------
        int
            Number of elements in the dataset.
        """
        return len(self.complexes)

    def __getitem__(self, idx) -> tuple[torch.Tensor, ...]:
        """Returns the idx-th element in the dataset.

        Parameters
        ----------
        idx : int
            Index of the element to return.

        Returns
        -------
        tuple of torch.Tensor
            Tuple containing the idx-th element in the dataset, including the input signals on nodes, edges and faces, the neighborhood matrices and the label.
        """
        return (
            self.x_0[idx],
            self.x_1[idx],
            self.x_2[idx],
            self.a0[idx],
            self.a1[idx],
            self.coa2[idx],
            self.b1[idx],
            self.b2[idx],
            self.y[idx],
        )

We load the data.

[4]:
shrec_training, shrec_testing = tnx.datasets.shrec_16()
Loading shrec 16 full dataset...

done!

Creating the train dataset and dataloader.

[5]:
training_dataset = SHRECDataset(shrec_training)
training_dataloader = DataLoader(training_dataset, batch_size=1, shuffle=True)

Creating the test dataset and dataloader.

[6]:
testing_dataset = SHRECDataset(shrec_testing)
testing_dataloader = DataLoader(testing_dataset, batch_size=1, shuffle=True)
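
As an optional sanity check, the dataset exposes the per-rank input dimensions and the number of classes; given the feature description above we expect 6, 10 and 7 channels and 30 classes.

print(training_dataset.channels_dim())  # expected: (6, 10, 7)
print(training_dataset.num_classes())   # expected: 30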

Create the Neural Network#

The task is to classify the meshes into their corresponding classes. To address this, we employ the Higher Order Attention Network Model for Mesh Classification, as outlined in the article Higher Order Attention Networks. This model integrates the hierarchical, attention-based message passing scheme described above. On top of it, a readout maps the node, edge, and face features into a shared N-dimensional Euclidean space, where N is the number of classes: the features of each rank are projected by a linear layer, averaged over the cells of that rank, and the three resulting vectors are summed.
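
In symbols, writing \(W_r\) for the per-rank linear maps and \(X_r\) for the set of rank-\(r\) cells of a mesh, the readout implemented in the `Network` class below computes

\[\hat{y} = \sum_{r=0}^{2} \frac{1}{|X_r|} \sum_{x \in X_r} W_r\, h_x^{(r)},\]

where \(h_x^{(r)}\) are the features produced by the last attention layer and \(\hat{y}\) is the vector of class scores fed to the cross-entropy loss.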

Train the Neural Network#

We create the trainer class. The model is trained using the Adam optimizer and the Cross Entropy Loss function.

[7]:
class Trainer:
    """Trainer for the HOANMeshClassifier.

    Parameters
    ----------
    model : torch.nn.Module
        The model to train.
    training_dataloader : torch.utils.data.DataLoader
        The dataloader for the training set.
    testing_dataloader : torch.utils.data.DataLoader
        The dataloader for the testing set.
    learning_rate : float
        The learning rate for the Adam optimizer.
    device : torch.device
        The device to use for training.
    """

    def __init__(
        self, model, training_dataloader, testing_dataloader, learning_rate, device
    ) -> None:
        self.model = model.to(device)
        self.training_dataloader = training_dataloader
        self.testing_dataloader = testing_dataloader
        self.device = device
        self.crit = torch.nn.CrossEntropyLoss()
        self.opt = torch.optim.Adam(model.parameters(), lr=learning_rate)

    def _to_device(self, x) -> list[torch.Tensor]:
        """Converts tensors to the correct type and moves them to the device.

        Parameters
        ----------
        x : List[torch.Tensor]
            List of tensors to convert.

        Returns
        -------
        List[torch.Tensor]
            List of converted tensors to float type and moved to the device.
        """

        return [el[0].float().to(self.device) for el in x]

    def train(self, num_epochs=500, test_interval=25) -> None:
        """Trains the model for the specified number of epochs.

        Parameters
        ----------
        num_epochs : int
            Number of epochs to train.
        test_interval : int
            Interval between testing epochs.
        """
        for epoch_i in range(num_epochs):
            training_accuracy, epoch_loss = self._train_epoch()
            print(
                f"Epoch: {epoch_i} loss: {epoch_loss:.4f} Train_acc: {training_accuracy:.4f}",
                flush=True,
            )
            if (epoch_i + 1) % test_interval == 0:
                test_accuracy = self.validate()
                print(f"Test_acc: {test_accuracy:.4f}", flush=True)

    def _train_epoch(self) -> tuple[float, float]:
        """Trains the model for one epoch.

        Returns
        -------
        training_accuracy : float
            The mean training accuracy for the epoch.
        epoch_loss : float
            The mean loss for the epoch.
        """
        training_samples = len(self.training_dataloader.dataset)
        total_loss = 0
        correct = 0
        self.model.train()
        for sample in self.training_dataloader:
            (
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
            ) = self._to_device(sample[:-1])

            self.opt.zero_grad()

            y_hat = self.model.forward(
                x_0,
                x_1,
                x_2,
                adjacency_0,
                adjacency_1,
                coadjacency_2,
                incidence_1,
                incidence_2,
            )

            y = sample[-1][0].long().to(self.device)
            total_loss += self._compute_loss_and_update(y_hat, y)
            correct += (y_hat.argmax() == y).sum().item()

        training_accuracy = correct / training_samples
        epoch_loss = total_loss / training_samples

        return training_accuracy, epoch_loss

    def _compute_loss_and_update(self, y_hat, y) -> float:
        """Computes the loss, performs backpropagation, and updates the model's parameters.

        Parameters
        ----------
        y_hat : torch.Tensor
            The output of the model.
        y : torch.Tensor
            The ground truth.

        Returns
        -------
        loss: float
            The loss value.
        """

        loss = self.crit(y_hat, y)
        loss.backward()
        self.opt.step()
        return loss.item()

    def validate(self) -> float:
        """Validates the model using the testing dataloader.

        Returns
        -------
        test_accuracy : float
            The mean testing accuracy.
        """
        correct = 0
        self.model.eval()
        test_samples = len(self.testing_dataloader.dataset)
        with torch.no_grad():
            for sample in self.testing_dataloader:
                (
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                ) = self._to_device(sample[:-1])

                y_hat = self.model(
                    x_0,
                    x_1,
                    x_2,
                    adjacency_0,
                    adjacency_1,
                    coadjacency_2,
                    incidence_1,
                    incidence_2,
                )
                y = sample[-1][0].long().to(self.device)
                correct += (y_hat.argmax() == y).sum().item()
            return correct / test_samples

We generate our Network, combining the HOAN model with an appropriate readout for the task at hand.

[8]:
class Network(torch.nn.Module):
    def __init__(
        self,
        channels_per_layer,
        negative_slope=0.2,
        num_classes=2,
    ):
        super().__init__()
        self.base_model = HMC(
            channels_per_layer,
            negative_slope,
        )
        self.l0 = torch.nn.Linear(channels_per_layer[-1][2][0], num_classes)
        self.l1 = torch.nn.Linear(channels_per_layer[-1][2][1], num_classes)
        self.l2 = torch.nn.Linear(channels_per_layer[-1][2][2], num_classes)

    def forward(
        self,
        x_0,
        x_1,
        x_2,
        neighborhood_0_to_0,
        neighborhood_1_to_1,
        neighborhood_2_to_2,
        neighborhood_0_to_1,
        neighborhood_1_to_2,
    ):
        x_0, x_1, x_2 = self.base_model(
            x_0,
            x_1,
            x_2,
            neighborhood_0_to_0,
            neighborhood_1_to_1,
            neighborhood_2_to_2,
            neighborhood_0_to_1,
            neighborhood_1_to_2,
        )
        x_0 = self.l0(x_0)
        x_1 = self.l1(x_1)
        x_2 = self.l2(x_2)

        # Average the features over the cells of each rank (mean pooling along dimension 0)
        x_0 = torch.nanmean(x_0, dim=0)
        x_1 = torch.nanmean(x_1, dim=0)
        x_2 = torch.nanmean(x_2, dim=0)

        return x_0 + x_1 + x_2

We define the parameters for the model. We use a softmax activation for the attention coefficients and a ReLU activation for the update and aggregation steps. We set the negative slope parameter of the LeakyReLU used inside the attention to 0.2. We use only one higher-order attention layer, as it already achieves almost perfect test accuracy, although more layers could be added.

[9]:
in_channels = training_dataset.channels_dim()
intermediate_channels = [60, 60, 60]
final_channels = [60, 60, 60]

channels_per_layer = [[in_channels, intermediate_channels, final_channels]]
# define the HOAN mesh classifier
model = Network(
    channels_per_layer, negative_slope=0.2, num_classes=training_dataset.num_classes()
)

# If GPUs are available, we will make use of them. Otherwise, this will run on CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

trainer = Trainer(model, training_dataloader, testing_dataloader, 0.001, device)
[10]:
model
[10]:
Network(
  (base_model): HMC(
    (layers): ModuleList(
      (0): HMCLayer(
        (hbs_0_level1): HBS(
          (weight): ParameterList(  (0): Parameter containing: [torch.float32 of size 6x60])
          (att_weight): ParameterList(  (0): Parameter containing: [torch.float32 of size 120x1])
        )
        (hbns_0_1_level1): HBNS()
        (hbns_1_2_level1): HBNS()
        (hbs_0_level2): HBS(
          (weight): ParameterList(  (0): Parameter containing: [torch.float32 of size 60x60])
          (att_weight): ParameterList(  (0): Parameter containing: [torch.float32 of size 120x1])
        )
        (hbns_0_1_level2): HBNS()
        (hbs_1_level2): HBS(
          (weight): ParameterList(  (0): Parameter containing: [torch.float32 of size 60x60])
          (att_weight): ParameterList(  (0): Parameter containing: [torch.float32 of size 120x1])
        )
        (hbns_1_2_level2): HBNS()
        (hbs_2_level2): HBS(
          (weight): ParameterList(  (0): Parameter containing: [torch.float32 of size 60x60])
          (att_weight): ParameterList(  (0): Parameter containing: [torch.float32 of size 120x1])
        )
        (aggr): Aggregation()
      )
    )
  )
  (l0): Linear(in_features=60, out_features=30, bias=True)
  (l1): Linear(in_features=60, out_features=30, bias=True)
  (l2): Linear(in_features=60, out_features=30, bias=True)
)

We train the HoanMeshClassifier for a small number of epochs, keeping training minimal for the purpose of rapid testing.

[11]:
trainer.train(num_epochs=5, test_interval=1)
/Users/gbg141/Documents/Projects/TopoModelX/topomodelx/nn/combinatorial/hmc_layer.py:683: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/SparseCsrTensorImpl.cpp:56.)
  A_p = torch.sparse.mm(A_p, neighborhood)
Epoch: 0 loss: 3.5569 Train_acc: 0.0292
Test_acc: 0.0667
Epoch: 1 loss: 3.2807 Train_acc: 0.0688
Test_acc: 0.1583
Epoch: 2 loss: 2.9899 Train_acc: 0.1125
Test_acc: 0.1417
Epoch: 3 loss: 2.6567 Train_acc: 0.1792
Test_acc: 0.1583
Epoch: 4 loss: 2.3474 Train_acc: 0.2583
Test_acc: 0.3250

Letting the model train for longer, we can see that it achieves near-perfect accuracy on both the training and testing sets.

[10]:
# trainer.train(num_epochs=30, test_interval=10)
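
If you do train for longer, a natural (hypothetical) follow-up is to re-evaluate on the test set and save the learned weights; the file name below is only an example.

# final_test_accuracy = trainer.validate()
# print(f"Final test accuracy: {final_test_accuracy:.4f}")
# torch.save(model.state_dict(), "hmc_mesh_classifier.pt")  # example path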