Train an SCCNN

In this notebook, we create and train a convolutional neural network in the simplicial complex domain, as proposed by Yang et al. in Convolutional Learning on Simplicial Complexes (2023).

We train the model to perform:

1.  Complex classification using the SHREC16 benchmark dataset.
2.  Node classification using the Karate Club dataset.

Simplicial Complex Convolutional Neural Networks [SCCNN]

SCCNN extends the simplicial convolutional neural network (SCNN) by also accounting for inter-simplicial couplings, i.e., contributions from simplices of adjacent orders.

For example, consider an SCCNN layer on an SC of order two. At layer \(t\), given the inputs on nodes, edges and faces, \(\mathbf{h}_{t}^0,\mathbf{h}_{t}^1\) and \(\mathbf{h}_{t}^2\), the layer computes the edge output as

\[\mathbf{h}_{t+1}^1 = \sigma \bigg[ \mathbf{F}_{t,\downarrow} \mathbf{B}_{1}^\top \mathbf{h}_{t}^{0} + \mathbf{F}_{t} \mathbf{h}_t^1 + \mathbf{F}_{t,\uparrow} \mathbf{B}_{2} \mathbf{h}_t^{2} \bigg]\]

where \(\mathbf{F}_t\) is the simplicial convolutional filter defined in the edge space, and \(\mathbf{F}_{t,\downarrow}\) and \(\mathbf{F}_{t,\uparrow}\) are the convolutional filters based only on the lower and the upper Laplacian, respectively. They take the following forms (the coefficients \(\theta\) are learnable and, in general, differ from one filter to another):

\[\mathbf{F}_{t} = {\theta}_t + \sum_{p_d=1}^{P_d} {\theta}_{t,p_d} (\mathbf{L}_{\downarrow,1})^{p_d} + \sum_{p_u=1}^{P_u} {\theta}_{t,p_u} (\mathbf{L}_{\uparrow,1})^{p_u}\]
\[\mathbf{F}_{t,\downarrow} = {\theta}_t + \sum_{p_d=1}^{P_d} {\theta}_{t,p_d} (\mathbf{L}_{\downarrow,1})^{p_d} \text{ and } \mathbf{F}_{t,\uparrow} = {\theta}_t + \sum_{p_u=1}^{P_u} {\theta}_{t,p_u} (\mathbf{L}_{\uparrow,1})^{p_u}\]
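To make the filter definitions concrete, the NumPy sketch below applies \(\mathbf{F}_t\) to a single-channel edge signal. It assumes scalar coefficients (one input and one output channel). The function `poly_filter` and its arguments are our own illustration and are not the TopoModelX API. Passing an empty `theta_up` (or `theta_down`) gives \(\mathbf{F}_{t,\downarrow}\) (or \(\mathbf{F}_{t,\uparrow}\)).

```python
import numpy as np


def poly_filter(x, theta0, theta_down, theta_up, lap_down, lap_up):
    """Apply F = theta0*I + sum_p theta_down[p-1] L_down^p + sum_p theta_up[p-1] L_up^p to x.

    Powers are applied through repeated matrix-vector products, so L^p is never formed.
    """
    out = theta0 * x
    z = x
    for coeff in theta_down:  # p_d = 1, ..., P_d
        z = lap_down @ z
        out = out + coeff * z
    z = x
    for coeff in theta_up:  # p_u = 1, ..., P_u
        z = lap_up @ z
        out = out + coeff * z
    return out


# Toy diagonal "Laplacians" so the matrix powers are easy to check by hand.
lap_down = np.diag([1.0, 2.0])
lap_up = np.diag([0.0, 3.0])
x = np.array([1.0, 1.0])

f_x = poly_filter(x, 1.0, [1.0, 1.0], [2.0], lap_down, lap_up)  # F_t x with P_d = 2, P_u = 1
f_down_x = poly_filter(x, 1.0, [1.0], [], lap_down, lap_up)  # F_{t,down} x: no upper terms
# F_t x = [3, 13] and F_{t,down} x = [2, 3]
```

Using repeated products keeps the cost at \(P_d + P_u\) sparse matrix-vector multiplications per filter. This matters because the real Laplacians are large and sparse, whereas their powers would be much denser.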

Likewise, for the node output, we have

\[\mathbf{h}_{t+1}^0 = \sigma \bigg[ \mathbf{F}_{t} \mathbf{h}_t^0 + \mathbf{F}_{t,\uparrow} \mathbf{B}_{1} \mathbf{h}_t^{1} \bigg]\]

where \(\mathbf{F}_t\) and \(\mathbf{F}_{t,\uparrow}\) are essentially two graph filters, i.e., polynomials in the graph Laplacian \(\mathbf{L}_0\) (nodes have no lower Laplacian).

For the face output, we have

\[\mathbf{h}_{t+1}^2 = \sigma \bigg[ \mathbf{F}_{t} \mathbf{h}_{t}^{2} + \mathbf{F}_{t,\downarrow} \mathbf{B}_{2}^\top \mathbf{h}_t^{1} \bigg]\]

where \(\mathbf{F}_t\) and \(\mathbf{F}_{t,\downarrow}\) are two simplicial filters defined in the triangle (face) space.
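Putting the three update equations together, here is a minimal single-channel NumPy sketch of one SCCNN layer on the smallest SC of order two, a filled triangle. It illustrates the equations under stated assumptions and is not TopoModelX's implementation. The orientation of the incidence matrices, the coefficient names in `th`, and the choice of tanh for \(\sigma\) are all our own.

```python
import numpy as np

# Filled triangle: nodes 0, 1, 2; edges [0,1], [0,2], [1,2]; one face [0,1,2].
B1 = np.array([[-1, -1, 0], [1, 0, -1], [0, 1, 1]], dtype=float)  # nodes x edges
B2 = np.array([[1], [-1], [1]], dtype=float)  # edges x faces

L0 = B1 @ B1.T  # graph Laplacian on nodes
L1_down = B1.T @ B1  # lower Laplacian on edges
L1_up = B2 @ B2.T  # upper Laplacian on edges
L2 = B2.T @ B2  # face Laplacian (only a lower part in an order-2 SC)


def poly(lap, x, coeffs):
    """Return sum_p coeffs[p] * lap^p @ x for p = 0, ..., len(coeffs) - 1."""
    out, z = coeffs[0] * x, x
    for c in coeffs[1:]:
        z = lap @ z
        out = out + c * z
    return out


def sccnn_layer(h0, h1, h2, th, sigma=np.tanh):
    """One single-channel SCCNN layer on the order-2 SC above, with scalar coefficients `th`."""
    # Nodes: sigma[F_t h0 + F_{t,up} B1 h1]; both filters are polynomials in L0.
    out0 = poly(L0, h0, th["n_self"]) + poly(L0, B1 @ h1, th["n_up"])
    # Edges: sigma[F_{t,down} B1^T h0 + F_t h1 + F_{t,up} B2 h2].
    f_self = (poly(L1_down, h1, th["e_self_down"])
              + poly(L1_up, h1, [0.0, *th["e_self_up"]]))
    out1 = (poly(L1_down, B1.T @ h0, th["e_down"]) + f_self
            + poly(L1_up, B2 @ h2, th["e_up"]))
    # Faces: sigma[F_t h2 + F_{t,down} B2^T h1]; both filters are polynomials in L2.
    out2 = poly(L2, h2, th["f_self"]) + poly(L2, B2.T @ h1, th["f_down"])
    return sigma(out0), sigma(out1), sigma(out2)


th = {
    "n_self": [1.0, -0.2], "n_up": [0.3, 0.1],
    # Edge self-filter: [theta_t, lower terms...] and [upper terms...]
    "e_self_down": [1.0, -0.2], "e_self_up": [-0.2],
    "e_down": [0.3, 0.1], "e_up": [0.3, 0.1],
    "f_self": [1.0, -0.2], "f_down": [0.3, 0.1],
}
rng = np.random.default_rng(0)
h0, h1, h2 = rng.normal(size=3), rng.normal(size=3), rng.normal(size=1)
out0, out1, out2 = sccnn_layer(h0, h1, h2, th)
print(out0.shape, out1.shape, out2.shape)  # (3,) (3,) (1,)
```

In a real layer, each scalar coefficient becomes a weight matrix mapping input channels to output channels.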

1. Complex Classification

 In [18]:
import numpy as np
import toponetx.datasets as datasets
import torch
from sklearn.model_selection import train_test_split

from topomodelx.nn.simplicial.sccnn import SCCNN
from topomodelx.utils.sparse import from_sparse

%load_ext autoreload
%autoreload 2

Import the SHREC16 dataset

Each SHREC16 mesh is already provided as a simplicial complex, together with its node, edge and face features and its label, so no lifting is needed.

 In [19]:
shrec, _ = datasets.mesh.shrec_16(size="small")
shrec = {key: np.array(value) for key, value in shrec.items()}
x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]
Loading shrec 16 small dataset...

done!
 In [20]:
in_channels_0 = x_0s[-1].shape[1]
in_channels_1 = x_1s[-1].shape[1]
in_channels_2 = x_2s[-1].shape[1]

in_channels_all = (in_channels_0, in_channels_1, in_channels_2)
print(in_channels_all)
(6, 10, 7)

Define Neighborhood Structures

Get the incidence matrices \(\mathbf{B}_1,\mathbf{B}_2\), the Hodge Laplacians \(\mathbf{L}_0\) and \(\mathbf{L}_2\), and the lower and upper Laplacians \(\mathbf{L}_{\downarrow,1}\) and \(\mathbf{L}_{\uparrow,1}\), whose sum is the Hodge Laplacian \(\mathbf{L}_1\).

Note that the original paper considers weighted versions of these operators. TopoNetX does not currently provide them, so we use the unweighted ones.
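In an SC of order two, these operators are tied together by \(\mathbf{L}_0 = \mathbf{B}_1\mathbf{B}_1^\top\), \(\mathbf{L}_{\downarrow,1} = \mathbf{B}_1^\top\mathbf{B}_1\), \(\mathbf{L}_{\uparrow,1} = \mathbf{B}_2\mathbf{B}_2^\top\) and \(\mathbf{L}_2 = \mathbf{B}_2^\top\mathbf{B}_2\). The NumPy sketch below builds signed incidence matrices for a toy complex and forms the Laplacians from them. The helper `incidence` and its orientation convention are our own assumptions. TopoNetX may order or orient simplices differently, which we would expect to change the matrices only up to permutations and sign flips.

```python
import numpy as np


def incidence(faces, simplices):
    """Signed incidence matrix between rank-(k-1) `faces` and rank-k `simplices`.

    Simplices are sorted vertex tuples; dropping the i-th vertex gives the sign (-1)**i.
    """
    index = {f: r for r, f in enumerate(faces)}
    B = np.zeros((len(faces), len(simplices)))
    for c, s in enumerate(simplices):
        for i in range(len(s)):
            B[index[s[:i] + s[i + 1:]], c] = (-1) ** i
    return B


# Toy SC of order two: a filled triangle {0, 1, 2} plus a pendant edge {2, 3}.
nodes = [(0,), (1,), (2,), (3,)]
edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
faces = [(0, 1, 2)]
B1 = incidence(nodes, edges)  # 4 nodes x 4 edges
B2 = incidence(edges, faces)  # 4 edges x 1 face

L0 = B1 @ B1.T  # Hodge Laplacian on nodes (the graph Laplacian)
L1_down = B1.T @ B1  # lower Laplacian on edges
L1_up = B2 @ B2.T  # upper Laplacian on edges
L1 = L1_down + L1_up  # Hodge Laplacian on edges
L2 = B2.T @ B2  # no rank-3 simplices, so L2 has only its lower part
```

A quick sanity check for any pair of incidence matrices is that `B1 @ B2` is zero (the boundary of a boundary vanishes).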

 In [21]:
max_rank = 2  # the order of the SC is two
incidence_1_list = []
incidence_2_list = []

laplacian_0_list = []
laplacian_down_1_list = []
laplacian_up_1_list = []
laplacian_2_list = []

for simplex in simplexes:
    incidence_1 = simplex.incidence_matrix(rank=1)
    incidence_2 = simplex.incidence_matrix(rank=2)
    laplacian_0 = simplex.hodge_laplacian_matrix(rank=0)
    laplacian_down_1 = simplex.down_laplacian_matrix(rank=1)
    laplacian_up_1 = simplex.up_laplacian_matrix(rank=1)
    laplacian_2 = simplex.hodge_laplacian_matrix(rank=2)

    incidence_1 = from_sparse(incidence_1)
    incidence_2 = from_sparse(incidence_2)
    laplacian_0 = from_sparse(laplacian_0)
    laplacian_down_1 = from_sparse(laplacian_down_1)
    laplacian_up_1 = from_sparse(laplacian_up_1)
    laplacian_2 = from_sparse(laplacian_2)

    incidence_1_list.append(incidence_1)
    incidence_2_list.append(incidence_2)
    laplacian_0_list.append(laplacian_0)
    laplacian_down_1_list.append(laplacian_down_1)
    laplacian_up_1_list.append(laplacian_up_1)
    laplacian_2_list.append(laplacian_2)

Create and Train the Neural Network

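We wrap SCCNN in a small network that maps the outputs on each rank to the output dimension with a linear layer and pools them into a single prediction per complex. The neighborhood structures are passed to the model at every forward call. We then specify an optimizer and a loss function.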

 In [22]:
class Network(torch.nn.Module):
    def __init__(
        self,
        in_channels_all,
        hidden_channels_all,
        out_channels,
        conv_order,
        max_rank,
        n_layers=2,
    ):
        super().__init__()
        self.base_model = SCCNN(
            in_channels_all=in_channels_all,
            hidden_channels_all=hidden_channels_all,
            conv_order=conv_order,
            sc_order=max_rank,
            n_layers=n_layers,
        )
        out_channels_0, out_channels_1, out_channels_2 = hidden_channels_all
        self.out_linear_0 = torch.nn.Linear(out_channels_0, out_channels)
        self.out_linear_1 = torch.nn.Linear(out_channels_1, out_channels)
        self.out_linear_2 = torch.nn.Linear(out_channels_2, out_channels)

    def forward(self, x_all, laplacian_all, incidence_all):
        x_all = self.base_model(x_all, laplacian_all, incidence_all)
        x_0, x_1, x_2 = x_all

        x_0 = self.out_linear_0(x_0)
        x_1 = self.out_linear_1(x_1)
        x_2 = self.out_linear_2(x_2)

        # Take the average of the 2D, 1D, and 0D cell features. If they are NaN, convert them to 0.
        two_dimensional_cells_mean = torch.nanmean(x_2, dim=0)
        two_dimensional_cells_mean[torch.isnan(two_dimensional_cells_mean)] = 0
        one_dimensional_cells_mean = torch.nanmean(x_1, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        zero_dimensional_cells_mean = torch.nanmean(x_0, dim=0)
        zero_dimensional_cells_mean[torch.isnan(zero_dimensional_cells_mean)] = 0
        # Return the sum of the averages
        return (
            two_dimensional_cells_mean
            + one_dimensional_cells_mean
            + zero_dimensional_cells_mean
        )
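The pooling at the end of `forward` averages each rank's outputs over its cells and sums the three averages, so every complex yields one prediction regardless of its size. A NumPy sketch of the same readout follows. `rank_mean` and `readout` are hypothetical names. Unlike the cell above, the sketch treats a rank with no cells explicitly instead of replacing NaN results after the fact.

```python
import warnings

import numpy as np


def rank_mean(x):
    """Per-channel mean over the cells of one rank; an empty rank or all-NaN channel gives 0."""
    if x.shape[0] == 0:
        return np.zeros(x.shape[1])
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)  # raised for all-NaN columns
        m = np.nanmean(x, axis=0)
    return np.nan_to_num(m, nan=0.0)


def readout(x_0, x_1, x_2):
    """Sum of the per-rank means: one prediction vector per complex."""
    return rank_mean(x_0) + rank_mean(x_1) + rank_mean(x_2)


# Three nodes, no edges, two faces, one output channel.
print(readout(np.ones((3, 1)), np.zeros((0, 1)), np.full((2, 1), 2.0)))  # [3.]
```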
 In [23]:
conv_order = 2
intermediate_channels_all = (16, 16, 16)
num_layers = 2
out_channels = 1  # a single output per complex, regressed onto the label with MSE

model = Network(
    in_channels_all=in_channels_all,
    hidden_channels_all=intermediate_channels_all,
    out_channels=out_channels,
    conv_order=conv_order,
    max_rank=max_rank,
    n_layers=num_layers,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss(reduction="mean")
print(model)
Network(
  (base_model): SCCNN(
    (in_linear_0): Linear(in_features=6, out_features=16, bias=True)
    (in_linear_1): Linear(in_features=10, out_features=16, bias=True)
    (in_linear_2): Linear(in_features=7, out_features=16, bias=True)
    (layers): ModuleList(
      (0-1): 2 x SCCNNLayer()
    )
  )
  (out_linear_0): Linear(in_features=16, out_features=1, bias=True)
  (out_linear_1): Linear(in_features=16, out_features=1, bias=True)
  (out_linear_2): Linear(in_features=16, out_features=1, bias=True)
)
 In [24]:
test_size = 0.2
x_0_train, x_0_test = train_test_split(x_0s, test_size=test_size, shuffle=False)
x_1_train, x_1_test = train_test_split(x_1s, test_size=test_size, shuffle=False)
x_2_train, x_2_test = train_test_split(x_2s, test_size=test_size, shuffle=False)

incidence_1_train, incidence_1_test = train_test_split(
    incidence_1_list, test_size=test_size, shuffle=False
)
incidence_2_train, incidence_2_test = train_test_split(
    incidence_2_list, test_size=test_size, shuffle=False
)
laplacian_0_train, laplacian_0_test = train_test_split(
    laplacian_0_list, test_size=test_size, shuffle=False
)
laplacian_down_1_train, laplacian_down_1_test = train_test_split(
    laplacian_down_1_list, test_size=test_size, shuffle=False
)
laplacian_up_1_train, laplacian_up_1_test = train_test_split(
    laplacian_up_1_list, test_size=test_size, shuffle=False
)
laplacian_2_train, laplacian_2_test = train_test_split(
    laplacian_2_list, test_size=test_size, shuffle=False
)

y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
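Every list above is split with `shuffle=False` and the same `test_size`, so each one is cut at the same position. This keeps features, operators and labels aligned. The same split can be written as a single helper that cuts all lists at one index. `split_aligned` is a hypothetical name, and rounding the test size up is our assumption about how `train_test_split` sizes the test set.

```python
import math


def split_aligned(sequences, test_size=0.2):
    """Cut several equally long sequences at the same index, without shuffling.

    The head goes to training and the tail to testing; the test size is rounded up.
    Works for lists and NumPy arrays, since both support slicing.
    """
    n = len(sequences[0])
    if any(len(s) != n for s in sequences):
        raise ValueError("all sequences must have the same length")
    n_train = n - math.ceil(test_size * n)
    train = [s[:n_train] for s in sequences]
    test = [s[n_train:] for s in sequences]
    return train, test


(xs_tr, ys_tr), (xs_te, ys_te) = split_aligned([list("abcdefghij"), list(range(10))])
print(xs_te, ys_te)  # ['i', 'j'] [8, 9]
```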

We train the SCCNN for only a few epochs, to keep this example quick to run.

 In [25]:
test_interval = 1
num_epochs = 5

for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for (
        x_0,
        x_1,
        x_2,
        incidence_1,
        incidence_2,
        laplacian_0,
        laplacian_down_1,
        laplacian_up_1,
        laplacian_2,
        y,
    ) in zip(
        x_0_train,
        x_1_train,
        x_2_train,
        incidence_1_train,
        incidence_2_train,
        laplacian_0_train,
        laplacian_down_1_train,
        laplacian_up_1_train,
        laplacian_2_train,
        y_train,
        strict=False,
    ):
        x_0 = torch.tensor(x_0)
        x_1 = torch.tensor(x_1)
        x_2 = torch.tensor(x_2)
        # Shape the target as (1,) to match the model output and avoid broadcasting
        y = torch.tensor(y, dtype=torch.float).reshape(1)
        optimizer.zero_grad()
        x_all = (x_0.float(), x_1.float(), x_2.float())
        laplacian_all = (laplacian_0, laplacian_down_1, laplacian_up_1, laplacian_2)
        incidence_all = (incidence_1, incidence_2)

        y_hat = model(x_all, laplacian_all, incidence_all)
        loss = loss_fn(y_hat, y)

        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()
        with torch.no_grad():
            for (
                x_0,
                x_1,
                x_2,
                incidence_1,
                incidence_2,
                laplacian_0,
                laplacian_down_1,
                laplacian_up_1,
                laplacian_2,
                y,
            ) in zip(
                x_0_test,
                x_1_test,
                x_2_test,
                incidence_1_test,
                incidence_2_test,
                laplacian_0_test,
                laplacian_down_1_test,
                laplacian_up_1_test,
                laplacian_2_test,
                y_test,
                strict=False,
            ):
                x_0 = torch.tensor(x_0)
                x_1 = torch.tensor(x_1)
                x_2 = torch.tensor(x_2)
                y = torch.tensor(y, dtype=torch.float).reshape(1)
                x_all = (x_0.float(), x_1.float(), x_2.float())
                laplacian_all = (
                    laplacian_0,
                    laplacian_down_1,
                    laplacian_up_1,
                    laplacian_2,
                )
                incidence_all = (incidence_1, incidence_2)

                y_hat = model(x_all, laplacian_all, incidence_all)

                loss = loss_fn(y_hat, y)
            # Note: this reports the loss of the last test complex only
            print(f"Test_loss: {loss:.4f}", flush=True)
Epoch: 1 loss: 399008.0930
Test_loss: 926.6080
Epoch: 2 loss: 477.8325
Test_loss: 204.4115
Epoch: 3 loss: 299.4982
Test_loss: 243.1712
Epoch: 4 loss: 202.5915
Test_loss: 302.4839
Epoch: 5 loss: 147.9002
Test_loss: 311.9497

2. Node Classification

 In [26]:
import toponetx.datasets.graph as graph

%load_ext autoreload
%autoreload 2

Pre-processing

The first step is to import the Karate Club (https://www.jstor.org/stable/3629752) dataset. This is a single graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift the graph into the simplicial complex domain.
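A common way to lift a graph to a simplicial complex is the clique complex, in which every set of \(k+1\) pairwise-connected nodes becomes a \(k\)-simplex. We assume that is the lift used here. Under that assumption, the brute-force sketch below counts the simplices of each rank. `clique_complex_counts` is a hypothetical helper and is only practical for small graphs.

```python
from itertools import combinations


def clique_complex_counts(n_nodes, edges, max_dim):
    """Number of k-simplices (k = 0..max_dim) in the clique complex of a graph."""
    adjacent = {frozenset(e) for e in edges}
    counts = []
    for k in range(max_dim + 1):
        # A (k+1)-subset of nodes is a k-simplex if all its pairs are edges.
        counts.append(
            sum(
                all(frozenset(pair) in adjacent for pair in combinations(subset, 2))
                for subset in combinations(range(n_nodes), k + 1)
            )
        )
    return counts


# The complete graph on 4 nodes lifts to a filled tetrahedron.
k4_edges = list(combinations(range(4), 2))
print(clique_complex_counts(4, k4_edges, max_dim=3))  # [4, 6, 4, 1]
```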

 In [27]:
dataset = graph.karate_club(complex_type="simplicial")
print(dataset)
max_rank = dataset.dim
print(max_rank)
Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4
4

Define Neighborhood Structures

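Get the incidence matrices \(\mathbf{B}_1,\mathbf{B}_2\), the Hodge Laplacian \(\mathbf{L}_0\), and the lower and upper Laplacians \(\mathbf{L}_{\downarrow,1}, \mathbf{L}_{\uparrow,1}\) and \(\mathbf{L}_{\downarrow,2}, \mathbf{L}_{\uparrow,2}\). This complex has dimension 4, so the rank-2 Laplacian has a nonzero upper part, and we pass its lower and upper parts separately.

As in the first example, we use the unweighted operators, since TopoNetX does not currently provide the weighted versions considered in the original paper.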
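The rank-2 Laplacian splits as \(\mathbf{L}_2 = \mathbf{L}_{\downarrow,2} + \mathbf{L}_{\uparrow,2}\), with \(\mathbf{L}_{\downarrow,2} = \mathbf{B}_2^\top\mathbf{B}_2\) and \(\mathbf{L}_{\uparrow,2} = \mathbf{B}_3\mathbf{B}_3^\top\). The upper part is zero in an SC of order two but becomes nonzero once rank-3 simplices exist. The NumPy sketch below checks this on a filled tetrahedron. It uses our own `incidence` helper and orientation convention, which are assumptions and not TopoNetX code.

```python
from itertools import combinations

import numpy as np


def incidence(faces, simplices):
    """Signed incidence matrix; dropping the i-th vertex of a sorted simplex gives (-1)**i."""
    index = {f: r for r, f in enumerate(faces)}
    B = np.zeros((len(faces), len(simplices)))
    for c, s in enumerate(simplices):
        for i in range(len(s)):
            B[index[s[:i] + s[i + 1:]], c] = (-1) ** i
    return B


# Filled tetrahedron on vertices 0..3: every vertex subset is a simplex.
simplices = [list(combinations(range(4), k + 1)) for k in range(4)]
B2 = incidence(simplices[1], simplices[2])  # 6 edges x 4 triangles
B3 = incidence(simplices[2], simplices[3])  # 4 triangles x 1 tetrahedron

L2_down = B2.T @ B2  # all an order-2 model would see on faces
L2_up = B3 @ B3.T  # nonzero only because a rank-3 simplex exists
L2 = L2_down + L2_up
print(np.allclose(L2_up, 0))  # False
```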

 In [28]:
incidence_1 = dataset.incidence_matrix(rank=1)
incidence_2 = dataset.incidence_matrix(rank=2)

print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
print(f"The incidence matrix B2 has shape: {incidence_2.shape}.")
The incidence matrix B1 has shape: (34, 78).
The incidence matrix B2 has shape: (78, 45).
 In [29]:
laplacian_0 = dataset.hodge_laplacian_matrix(rank=0)
laplacian_down_1 = dataset.down_laplacian_matrix(rank=1)
laplacian_up_1 = dataset.up_laplacian_matrix(rank=1)
laplacian_down_2 = dataset.down_laplacian_matrix(rank=2)
laplacian_up_2 = dataset.up_laplacian_matrix(rank=2)
 In [30]:
laplacian_0 = from_sparse(laplacian_0)
laplacian_down_1 = from_sparse(laplacian_down_1)
laplacian_up_1 = from_sparse(laplacian_up_1)
laplacian_down_2 = from_sparse(laplacian_down_2)
laplacian_up_2 = from_sparse(laplacian_up_2)

incidence_1 = from_sparse(incidence_1)
incidence_2 = from_sparse(incidence_2)

Import signal

We retrieve an input signal on the nodes, edges and faces. Each signal has shape \(n_k \times\) in_channels, where \(n_k\) is the number of rank-\(k\) simplices and in_channels is the dimension of each simplex's feature. Here, in_channels = 2 on nodes, edges and faces.

 In [31]:
"""A function to obtain features based on the input: rank
"""


def get_simplicial_features(dataset, rank):
    if rank == 0:
        which_feat = "node_feat"
    elif rank == 1:
        which_feat = "edge_feat"
    elif rank == 2:
        which_feat = "face_feat"
    else:
        raise ValueError(
            "input dimension must be 0, 1 or 2, because features are supported on nodes, edges and faces"
        )

    x = list(dataset.get_simplex_attributes(which_feat).values())
    return torch.tensor(np.stack(x))
 In [32]:
x_0 = get_simplicial_features(dataset, rank=0)
x_1 = get_simplicial_features(dataset, rank=1)
x_2 = get_simplicial_features(dataset, rank=2)
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")
There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.

Define binary labels

We define a label for each node. In the Karate Club dataset, the club splits into two social groups, so each node gets a binary label indicating which group it belongs to.

We convert the binary labels into one-hot form, train on the first 30 nodes, and keep the labels of the last four nodes for testing.

 In [34]:
# Group membership of each of the 34 nodes
y = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
     1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)
# One-hot encoding: column 0 holds y, column 1 holds 1 - y
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
# Train on the first 30 nodes, test on the last 4
y_train = y_true[:30]
y_test = y_true[-4:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)
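The one-hot matrix can also be built in one step by indexing the rows of an identity matrix. Below is a minimal NumPy sketch with a short, hypothetical label vector. On the real data, `y` would be the 34 labels above and `n_test` would be 4.

```python
import numpy as np

y = np.array([1, 1, 0, 1, 0, 0])  # hypothetical labels for six nodes
# np.eye(2)[1 - y[i]] is [1, 0] when y[i] == 1 and [0, 1] when y[i] == 0,
# so column 0 holds y and column 1 holds 1 - y.
y_true = np.eye(2)[1 - y]

n_test = 2  # keep the labels of the last n_test nodes for testing
y_train, y_test = y_true[:-n_test], y_true[-n_test:]
```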

Create and Train the Neural Network

We wrap SCCNN in a network that maps the node outputs to per-node probabilities with a linear layer, and then specify an optimizer.

 In [45]:
class Network(torch.nn.Module):
    def __init__(
        self,
        in_channels_all,
        hidden_channels_all,
        out_channels,
        conv_order,
        max_rank,
        update_func=None,
        n_layers=2,
    ):
        super().__init__()
        self.base_model = SCCNN(
            in_channels_all=in_channels_all,
            hidden_channels_all=hidden_channels_all,
            conv_order=conv_order,
            sc_order=max_rank,
            update_func=update_func,
            n_layers=n_layers,
        )
        out_channels_0, _, _ = hidden_channels_all
        self.out_linear_0 = torch.nn.Linear(out_channels_0, out_channels)

    def forward(self, x_all, laplacian_all, incidence_all):
        x_all = self.base_model(x_all, laplacian_all, incidence_all)

        # Pass the node outputs through a linear layer and turn them into
        # per-node probabilities with a sigmoid
        x_0, _, _ = x_all
        logits = self.out_linear_0(x_0)

        return torch.sigmoid(logits)
 In [46]:
"""Obtain the initial features on all simplices"""
x_all = (x_0, x_1, x_2)

conv_order = 2
in_channels_all = (x_0.shape[-1], x_1.shape[-1], x_2.shape[-1])
intermediate_channels_all = (16, 16, 16)
num_layers = 2
out_channels = 2  # num classes

laplacian_all = (
    laplacian_0,
    laplacian_down_1,
    laplacian_up_1,
    laplacian_down_2,
    laplacian_up_2,
)

incidence_all = (incidence_1, incidence_2)

model = Network(
    in_channels_all=in_channels_all,
    hidden_channels_all=intermediate_channels_all,
    out_channels=out_channels,
    conv_order=conv_order,
    max_rank=max_rank,
    update_func="sigmoid",
    n_layers=num_layers,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
 In [47]:
test_interval = 10
num_epochs = 100
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()
    y_hat = model(x_all, laplacian_all, incidence_all)
    y_hat = torch.softmax(y_hat, dim=1)
    loss = torch.nn.functional.binary_cross_entropy(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        model.eval()
        with torch.no_grad():
            y_hat_test = model(x_all, laplacian_all, incidence_all)
            # Normalize the two class scores of each node with a softmax, as in training
            y_hat_test = torch.softmax(y_hat_test, dim=1)
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)
Epoch: 1 loss: 0.6788 Train_acc: 0.4667
Epoch: 2 loss: 0.6248 Train_acc: 0.8000
Epoch: 3 loss: 0.6677 Train_acc: 0.7667
Epoch: 4 loss: 0.5903 Train_acc: 0.8000
Epoch: 5 loss: 0.5934 Train_acc: 0.8000
Epoch: 6 loss: 0.5666 Train_acc: 0.8000
Epoch: 7 loss: 0.5496 Train_acc: 0.8000
Epoch: 8 loss: 0.5381 Train_acc: 0.8000
Epoch: 9 loss: 0.5306 Train_acc: 0.8000
Epoch: 10 loss: 0.5257 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 11 loss: 0.5223 Train_acc: 0.8000
Epoch: 12 loss: 0.5198 Train_acc: 0.8000
Epoch: 13 loss: 0.5182 Train_acc: 0.8000
Epoch: 14 loss: 0.5170 Train_acc: 0.8000
Epoch: 15 loss: 0.5162 Train_acc: 0.8000
Epoch: 16 loss: 0.5156 Train_acc: 0.8000
Epoch: 17 loss: 0.5151 Train_acc: 0.8000
Epoch: 18 loss: 0.5147 Train_acc: 0.8000
Epoch: 19 loss: 0.5144 Train_acc: 0.8000
Epoch: 20 loss: 0.5142 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 21 loss: 0.5140 Train_acc: 0.8000
Epoch: 22 loss: 0.5139 Train_acc: 0.8000
Epoch: 23 loss: 0.5138 Train_acc: 0.8000
Epoch: 24 loss: 0.5137 Train_acc: 0.8000
Epoch: 25 loss: 0.5136 Train_acc: 0.8000
Epoch: 26 loss: 0.5135 Train_acc: 0.8000
Epoch: 27 loss: 0.5135 Train_acc: 0.8000
Epoch: 28 loss: 0.5134 Train_acc: 0.8000
Epoch: 29 loss: 0.5134 Train_acc: 0.8000
Epoch: 30 loss: 0.5133 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 31 loss: 0.5133 Train_acc: 0.8000
Epoch: 32 loss: 0.5133 Train_acc: 0.8000
Epoch: 33 loss: 0.5132 Train_acc: 0.8000
Epoch: 34 loss: 0.5132 Train_acc: 0.8000
Epoch: 35 loss: 0.5132 Train_acc: 0.8000
Epoch: 36 loss: 0.5132 Train_acc: 0.8000
Epoch: 37 loss: 0.5131 Train_acc: 0.8000
Epoch: 38 loss: 0.5131 Train_acc: 0.8000
Epoch: 39 loss: 0.5131 Train_acc: 0.8000
Epoch: 40 loss: 0.5131 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 41 loss: 0.5130 Train_acc: 0.8000
Epoch: 42 loss: 0.5130 Train_acc: 0.8000
Epoch: 43 loss: 0.5130 Train_acc: 0.8000
Epoch: 44 loss: 0.5130 Train_acc: 0.8000
Epoch: 45 loss: 0.5129 Train_acc: 0.8000
Epoch: 46 loss: 0.5129 Train_acc: 0.8000
Epoch: 47 loss: 0.5129 Train_acc: 0.8000
Epoch: 48 loss: 0.5128 Train_acc: 0.8000
Epoch: 49 loss: 0.5128 Train_acc: 0.8000
Epoch: 50 loss: 0.5128 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 51 loss: 0.5127 Train_acc: 0.8000
Epoch: 52 loss: 0.5127 Train_acc: 0.8000
Epoch: 53 loss: 0.5126 Train_acc: 0.8000
Epoch: 54 loss: 0.5126 Train_acc: 0.8000
Epoch: 55 loss: 0.5125 Train_acc: 0.8000
Epoch: 56 loss: 0.5124 Train_acc: 0.8000
Epoch: 57 loss: 0.5124 Train_acc: 0.8000
Epoch: 58 loss: 0.5123 Train_acc: 0.8000
Epoch: 59 loss: 0.5122 Train_acc: 0.8000
Epoch: 60 loss: 0.5121 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 61 loss: 0.5120 Train_acc: 0.8000
Epoch: 62 loss: 0.5119 Train_acc: 0.8000
Epoch: 63 loss: 0.5118 Train_acc: 0.8000
Epoch: 64 loss: 0.5116 Train_acc: 0.8000
Epoch: 65 loss: 0.5115 Train_acc: 0.8000
Epoch: 66 loss: 0.5113 Train_acc: 0.8000
Epoch: 67 loss: 0.5111 Train_acc: 0.8000
Epoch: 68 loss: 0.5109 Train_acc: 0.8000
Epoch: 69 loss: 0.5107 Train_acc: 0.8000
Epoch: 70 loss: 0.5105 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 71 loss: 0.5103 Train_acc: 0.8000
Epoch: 72 loss: 0.5101 Train_acc: 0.8000
Epoch: 73 loss: 0.5099 Train_acc: 0.8000
Epoch: 74 loss: 0.5098 Train_acc: 0.8000
Epoch: 75 loss: 0.5096 Train_acc: 0.8000
Epoch: 76 loss: 0.5094 Train_acc: 0.8000
Epoch: 77 loss: 0.5092 Train_acc: 0.8000
Epoch: 78 loss: 0.5089 Train_acc: 0.8000
Epoch: 79 loss: 0.5085 Train_acc: 0.8000
Epoch: 80 loss: 0.5082 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 81 loss: 0.5078 Train_acc: 0.8000
Epoch: 82 loss: 0.5075 Train_acc: 0.8000
Epoch: 83 loss: 0.5071 Train_acc: 0.8000
Epoch: 84 loss: 0.5068 Train_acc: 0.8000
Epoch: 85 loss: 0.5064 Train_acc: 0.8000
Epoch: 86 loss: 0.5060 Train_acc: 0.8000
Epoch: 87 loss: 0.5056 Train_acc: 0.8000
Epoch: 88 loss: 0.5051 Train_acc: 0.8000
Epoch: 89 loss: 0.5046 Train_acc: 0.8000
Epoch: 90 loss: 0.5041 Train_acc: 0.8000
Test_acc: 0.5000
Epoch: 91 loss: 0.5037 Train_acc: 0.8000
Epoch: 92 loss: 0.5032 Train_acc: 0.8000
Epoch: 93 loss: 0.5027 Train_acc: 0.8000
Epoch: 94 loss: 0.5022 Train_acc: 0.8000
Epoch: 95 loss: 0.5016 Train_acc: 0.8000
Epoch: 96 loss: 0.5011 Train_acc: 0.8000
Epoch: 97 loss: 0.5005 Train_acc: 0.8000
Epoch: 98 loss: 0.5000 Train_acc: 0.8000
Epoch: 99 loss: 0.4994 Train_acc: 0.8000
Epoch: 100 loss: 0.4989 Train_acc: 0.8000
Test_acc: 0.5000
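For each node, the softmax outputs for the two classes sum to one. Thresholding them at 0.5 and requiring both entries to match the one-hot label is therefore equivalent to comparing the argmax with the label, except for an exact tie at 0.5. The NumPy sketch below computes both versions; the helper names and toy values are hypothetical.

```python
import numpy as np


def accuracy_threshold(probs, y_onehot):
    """Accuracy as computed in the training loop: threshold at 0.5, every entry must match."""
    pred = (probs > 0.5).astype(int)
    return float(np.mean(np.all(pred == y_onehot, axis=1)))


def accuracy_argmax(probs, y_onehot):
    """Fraction of rows whose most probable class matches the one-hot label."""
    return float(np.mean(np.argmax(probs, axis=1) == np.argmax(y_onehot, axis=1)))


probs = np.array([[0.7, 0.3], [0.4, 0.6], [0.2, 0.8], [0.9, 0.1]])  # rows sum to 1
labels = np.array([[1, 0], [1, 0], [0, 1], [0, 1]])
print(accuracy_threshold(probs, labels), accuracy_argmax(probs, labels))  # 0.5 0.5
```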