Train a Simplicial Neural Network for Homology Localization (Dist2Cycle)#

In this notebook, we will create and train a Simplicial Neural Network for Homology Localization, as proposed by Alexandros D. Keros et al. in Dist2Cycle: A Simplicial Neural Network for Homology Localization (2022).

We train the model to perform binary node classification using the KarateClub benchmark dataset.

The equations of one layer of this neural network are given by:

🟥 \(\quad m^{(1 \rightarrow 1)}_{y \rightarrow x} = \left(A \odot (I + L_{\downarrow})^{+}\right)_{xy} \cdot h_{y}^{t,(1)} \cdot \Theta^t\)

🟧 \(\quad m_x^{(1 \rightarrow 1)} = \sum_{y \in \mathcal{L}_{\downarrow}(x)} m_{y \rightarrow x}^{(1 \rightarrow 1)}\)

🟩 \(\quad m_x^{(1)} = m^{(1 \rightarrow 1)}_x\)

🟦 \(\quad h_x^{t+1,(1)} = \sigma(m_{x}^{(1)})\)

The notation is defined in Papillon et al., Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023).
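To make the four equations concrete, here is a minimal dense NumPy sketch of one layer on a hypothetical toy complex (a triangle graph with no filled face). It is not TopoModelX's `Dist2Cycle` implementation: the function name `dist2cycle_layer`, the choice σ = tanh, and the edge orientation are our assumptions, and `A` is assumed to be a 0/1 mask whose nonzero entries in row x are exactly the neighbors y of edge x.

```python
import numpy as np


def dist2cycle_layer(h, A, L_down, theta, sigma=np.tanh):
    """Dense sketch of one layer: sigma((A * pinv(I + L_down)) @ h @ theta).

    h:      (n_edges, c_in)     edge features h^{t,(1)}
    A:      (n_edges, n_edges)  0/1 mask of neighbors (assumption)
    L_down: (n_edges, n_edges)  down Laplacian of the edges
    theta:  (c_in, c_out)       weights Theta^t
    sigma:  nonlinearity; tanh is an arbitrary choice for this sketch
    """
    n_edges = L_down.shape[0]
    # Coefficient of each message y -> x: entry (x, y) of A ⊙ (I + L_down)^+
    operator = A * np.linalg.pinv(np.eye(n_edges) + L_down)
    # Row x of operator @ h @ theta is the sum of the messages m_{y->x}
    m = operator @ h @ theta
    # A single neighborhood, so m_x^{(1)} = m_x^{(1->1)}; then apply sigma
    return sigma(m)


# Hypothetical triangle graph, edges (0,1), (0,2), (1,2); -1 at tail, +1 at head
B1 = np.array([[-1, -1, 0], [1, 0, -1], [0, 1, 1]], dtype=float)
L_down = B1.T @ B1
A = np.ones((3, 3)) - np.eye(3)  # every pair of edges shares a vertex
rng = np.random.default_rng(0)
h = rng.normal(size=(3, 4))
theta = rng.normal(size=(4, 2))
h_next = dist2cycle_layer(h, A, L_down, theta)
print(h_next.shape)  # (3, 2)
```

Because `A` has a zero diagonal, no edge sends a message to itself. For the 78 edges of the Karate Club complex, a dense 78 × 78 pseudo-inverse is cheap to compute, which is why the notebook can use `npla.pinv` directly before converting the result to a sparse tensor.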

[1]:
import numpy as np
import numpy.linalg as npla
import toponetx as tnx
import torch

from topomodelx.nn.simplicial.dist2cycle import Dist2Cycle
from topomodelx.utils.sparse import from_sparse

%load_ext autoreload
%autoreload 2

Pre-processing#

Import dataset#

The first step is to import the Karate Club dataset (https://www.jstor.org/stable/3629752). It consists of a single graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.

We must first lift our graph dataset into the simplicial complex domain.
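The printed shape lists the number of simplices of each dimension. One common way to lift a graph is the clique complex, in which every set of k + 1 pairwise-adjacent nodes becomes a k-simplex. We assume, without checking the TopoNetX source, that `complex_type="simplicial"` performs this kind of lifting. The brute-force sketch below (the helper `clique_complex_counts` is hypothetical) counts the simplices of a clique complex for a small graph.

```python
from itertools import combinations


def clique_complex_counts(n_nodes, edges, max_dim):
    """Count the k-simplices of the clique complex of a graph, k = 0..max_dim.

    A k-simplex is a set of k + 1 vertices that are pairwise connected.
    Brute force over vertex subsets, so only suitable for small graphs.
    """
    edge_set = {frozenset(e) for e in edges}
    counts = []
    for k in range(max_dim + 1):
        n_simplices = sum(
            1
            for subset in combinations(range(n_nodes), k + 1)
            if all(frozenset(pair) in edge_set for pair in combinations(subset, 2))
        )
        counts.append(n_simplices)
    return counts


# Complete graph on 4 vertices: 4 nodes, 6 edges, 4 triangles, 1 tetrahedron
k4_edges = list(combinations(range(4), 2))
print(clique_complex_counts(4, k4_edges, max_dim=3))  # [4, 6, 4, 1]
```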

[2]:
dataset = tnx.datasets.karate_club(complex_type="simplicial")
print(dataset)
Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4

Define neighborhood structures#

Now we retrieve the neighborhood structures (i.e., their representative matrices) that we will use to send messages on the domain. In this case, we need the boundary matrix (or incidence matrix) \(B_1\) and the adjacency matrix \(A_{\uparrow,0}\) on the nodes. As a sanity check, we verify that \(B_1\) has shape \(n_\text{nodes} \times n_\text{edges}\) and \(A_{\uparrow,0}\) has shape \(n_\text{nodes} \times n_\text{nodes}\). We also convert the neighborhood structures to torch tensors.

[3]:
incidence_1 = dataset.incidence_matrix(rank=1)
adjacency_0 = dataset.adjacency_matrix(rank=0)

incidence_1 = from_sparse(incidence_1)
adjacency_0 = from_sparse(adjacency_0)

print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
print(f"The adjacency matrix A0 has shape: {adjacency_0.shape}.")
The incidence matrix B1 has shape: torch.Size([34, 78]).
The adjacency matrix A0 has shape: torch.Size([34, 34]).
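The two structures are related: for a simple graph, \(B_1 B_1^\top = D - A_{\uparrow,0}\), where \(D\) is the diagonal degree matrix. The NumPy sketch below rebuilds both matrices for a small hypothetical graph. The helper names and the orientation convention (-1 at the lower-index vertex, +1 at the higher one) are our assumptions; the resulting adjacency does not depend on the orientation.

```python
import numpy as np


def boundary_matrix_1(n_nodes, edges):
    """Signed node-to-edge boundary matrix B_1 of shape (n_nodes, n_edges).

    Assumed orientation: edge (i, j) gets -1 at min(i, j) and +1 at max(i, j).
    """
    B1 = np.zeros((n_nodes, len(edges)))
    for col, (i, j) in enumerate(edges):
        B1[min(i, j), col] = -1.0
        B1[max(i, j), col] = 1.0
    return B1


def node_adjacency_from_boundary(B1):
    """A_{up,0} from B_1, using B_1 B_1^T = D - A for a simple graph."""
    L0 = B1 @ B1.T
    return np.diag(np.diag(L0)) - L0


# Hypothetical graph: a triangle on nodes 0, 1, 2 plus an isolated node 3
edges = [(0, 1), (0, 2), (1, 2)]
B1 = boundary_matrix_1(4, edges)
A0 = node_adjacency_from_boundary(B1)
print(B1.shape, A0.shape)  # (4, 3) (4, 4)
```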

Import signal#

Since our task will be node classification, we must retrieve an input signal on the nodes. The signal will have shape \(n_\text{nodes} \times\) in_channels, where in_channels is the dimension of each cell's feature. Here, in_channels = channels_nodes = 2: each of the 34 nodes carries a 2-dimensional feature vector, as the output below confirms.

[4]:
x_0 = list(dataset.get_simplex_attributes("node_feat").values())
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
print(channels_nodes)
2
[5]:
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")
There are 34 nodes with features of dimension 2.

Edge features can be loaded in the same way. We will not use them in this model; the cell below only serves as a demonstration.

[6]:
x_1 = list(dataset.get_simplex_attributes("edge_feat").values())
x_1 = np.stack(x_1)
[7]:
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")
There are 78 edges with features of dimension 2.

Similarly for face features:

[8]:
x_2 = list(dataset.get_simplex_attributes("face_feat").values())
x_2 = np.stack(x_2)
[9]:
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")
There are 45 faces with features of dimension 2.

Define binary labels#

We now define the labels of the nodes. In the Karate Club dataset, two social groups emerge, so we assign each node a binary label indicating which group it belongs to.

We convert the binary labels into one-hot form and hold out the true labels of the last four nodes for testing; the first 30 nodes are used for training.

[10]:
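y = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
     1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)
# One-hot encoding: column 0 holds the label, column 1 its complement
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
# The first 30 nodes are used for training, the last 4 for testing
y_train = y_true[:30]
y_test = y_true[-4:]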

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

print(y_train.shape)
torch.Size([30, 2])
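The label encoding above can be written more compactly. The sketch below is only an illustration that mirrors the variable names of the cell above: it rebuilds `y_true` with `np.stack`, checks it against an equivalent one-hot lookup through `np.eye(2)`, and makes the train/test split explicit. Note that the four held-out test nodes all carry label 0.

```python
import numpy as np

# Same 34 labels as above (1 = first group, 0 = second group)
y = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
     1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)
# Column 0 holds y and column 1 holds 1 - y, as in y_true above
y_true = np.stack([y, 1 - y], axis=1).astype(float)
# Equivalent one-hot lookup: label 1 -> [1, 0], label 0 -> [0, 1]
y_true_eye = np.eye(2)[1 - y]

n_test = 4
y_train, y_test = y_true[:-n_test], y_true[-n_test:]
print(y_train.shape, y_test.shape)  # (30, 2) (4, 2)
```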

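Create Features#

Dist2Cycle operates on the edges. Following the message equation above, we compute the down Laplacian \(L_{\downarrow,1}\) and the adjacency matrix \(A_1\) of the edges, take the pseudo-inverse \((I + L_{\downarrow,1})^+\), and form the element-wise product \(A_1 \odot (I + L_{\downarrow,1})^+\). This product serves as the input edge features, while the pseudo-inverse and the adjacency matrix are passed to the model as neighborhood structures. We also keep the incidence matrix \(B_1\) to project the edge-level output back to the nodes.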

[11]:
# Down Laplacian L_down,1 and adjacency matrix A_1 of the edges, as dense matrices
ld = dataset.down_laplacian_matrix(rank=1).todense()
A = dataset.adjacency_matrix(rank=1).todense()
# Pseudo-inverse of the regularized down Laplacian, (I + L_down,1)^+
L_tilde_pinv = npla.pinv(ld + np.eye(ld.shape[0]))
# Each edge feature is a row of a 78 x 78 matrix, so the feature dimension is 78
channels_nodes = L_tilde_pinv.shape[-1]
print(channels_nodes)
print(np.array(A).shape)
print(np.array(ld).shape)
print(x_1.shape)  # edge features loaded above (not used by the model)
print(L_tilde_pinv.shape)

adjacency = torch.from_numpy(A).float().to_sparse()
Linv = torch.from_numpy(L_tilde_pinv).float().to_sparse()

# Element-wise (Hadamard) product A_1 ⊙ (I + L_down,1)^+, used as input edge features
x_1e = (adjacency * Linv).to_sparse()

# Incidence matrix B_1 (nodes x edges), used to project edge outputs to the nodes
incidence_0_1 = from_sparse(dataset.incidence_matrix(1))
78
(78, 78)
(78, 78)
(78, 2)
(78, 78)
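The cell above uses the pseudo-inverse `npla.pinv`. The down Laplacian \(L_{\downarrow,1} = B_1^\top B_1\) is symmetric positive semi-definite and generally singular, but adding \(I\) shifts every eigenvalue to at least 1, so \(I + L_{\downarrow,1}\) is invertible and its pseudo-inverse coincides with its ordinary inverse. The sketch below checks this on a hypothetical triangle graph with no filled face, assuming the orientation -1 at the lower-index vertex and +1 at the higher one.

```python
import numpy as np

# Hypothetical toy complex: edges (0,1), (0,2), (1,2), no filled face
B1 = np.array([[-1, -1, 0], [1, 0, -1], [0, 1, 1]], dtype=float)
L_down = B1.T @ B1  # symmetric positive semi-definite
eigenvalues = np.linalg.eigvalsh(L_down)  # approximately [0, 3, 3]: singular

M = np.eye(3) + L_down  # every eigenvalue is now at least 1
# For an invertible matrix the Moore-Penrose pseudo-inverse equals the inverse
print(np.allclose(np.linalg.pinv(M), np.linalg.inv(M)))  # True
```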

Create the Neural Network#

Using the Dist2Cycle class, we create our neural network with stacked layers. A final linear layer maps the edge embeddings to an output of shape \(n_\text{edges} \times 2\); in the training loop, we project this output to the nodes with \(B_1\), obtaining shape \(n_\text{nodes} \times 2\) so that we can compare it with our binary labels.

[15]:
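class Network(torch.nn.Module):
    def __init__(self, channels, out_channels, n_layers=2):
        super().__init__()
        # Stack of Dist2Cycle layers operating on edge features
        self.base_model = Dist2Cycle(
            channels=channels,
            n_layers=n_layers,
        )
        # Linear readout to the two classes
        self.linear = torch.nn.Linear(channels, out_channels)

    def forward(self, x, Linv, adjacency):
        # x: edge features; Linv: (I + L_down,1)^+; adjacency: edge adjacency matrix
        x = self.base_model(x, Linv, adjacency)
        x = self.linear(x)
        # Per-edge class probabilities, shape (n_edges, out_channels)
        return torch.softmax(x, dim=1)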
[22]:
out_channels = 2
n_layers = 3

model = Network(
    channels=channels_nodes,
    out_channels=out_channels,
    n_layers=n_layers,
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

Train the Neural Network#

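We run the model on the precomputed edge features and neighborhood structures, and update its parameters with the Adam optimizer defined above.

The following cell trains the network for 100 epochs, printing the training loss and accuracy at every epoch and the test accuracy every 10 epochs.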

[23]:
test_interval = 10
num_epochs = 100
for epoch_i in range(1, num_epochs + 1):
    model.train()
    optimizer.zero_grad()

    # Edge-level output of the model, shape (n_edges, 2)
    y_hat_edge = model(x_1e, Linv, adjacency)
    # Project the edge-level output to the nodes with B_1 and apply softmax
    # to get the final node-level classification output, shape (n_nodes, 2)
    y_hat = torch.softmax(torch.sparse.mm(incidence_0_1, y_hat_edge), dim=1)
    # Note: y_hat already contains probabilities, and
    # binary_cross_entropy_with_logits applies a sigmoid to them again
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    loss.backward()
    optimizer.step()

    # A node counts as correct only if both one-hot entries match
    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {loss.item():.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_edge_test = model(x_1e, Linv, adjacency)
            # Projection to the node level
            y_hat_test = torch.softmax(
                torch.sparse.mm(incidence_0_1, y_hat_edge_test), dim=1
            )
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)
Epoch: 1 loss: 0.7234 Train_acc: 0.6333
Epoch: 2 loss: 0.6930 Train_acc: 0.6667
Epoch: 3 loss: 0.6769 Train_acc: 0.7000
Epoch: 4 loss: 0.6718 Train_acc: 0.7000
Epoch: 5 loss: 0.6701 Train_acc: 0.7000
Epoch: 6 loss: 0.6694 Train_acc: 0.7000
Epoch: 7 loss: 0.6691 Train_acc: 0.7000
Epoch: 8 loss: 0.6689 Train_acc: 0.7000
Epoch: 9 loss: 0.6689 Train_acc: 0.7000
Epoch: 10 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 11 loss: 0.6688 Train_acc: 0.7000
Epoch: 12 loss: 0.6688 Train_acc: 0.7000
Epoch: 13 loss: 0.6688 Train_acc: 0.7000
Epoch: 14 loss: 0.6688 Train_acc: 0.7000
Epoch: 15 loss: 0.6688 Train_acc: 0.7000
Epoch: 16 loss: 0.6688 Train_acc: 0.7000
Epoch: 17 loss: 0.6688 Train_acc: 0.7000
Epoch: 18 loss: 0.6688 Train_acc: 0.7000
Epoch: 19 loss: 0.6688 Train_acc: 0.7000
Epoch: 20 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 21 loss: 0.6688 Train_acc: 0.7000
Epoch: 22 loss: 0.6688 Train_acc: 0.7000
Epoch: 23 loss: 0.6688 Train_acc: 0.7000
Epoch: 24 loss: 0.6688 Train_acc: 0.7000
Epoch: 25 loss: 0.6688 Train_acc: 0.7000
Epoch: 26 loss: 0.6688 Train_acc: 0.7000
Epoch: 27 loss: 0.6688 Train_acc: 0.7000
Epoch: 28 loss: 0.6688 Train_acc: 0.7000
Epoch: 29 loss: 0.6688 Train_acc: 0.7000
Epoch: 30 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 31 loss: 0.6688 Train_acc: 0.7000
Epoch: 32 loss: 0.6688 Train_acc: 0.7000
Epoch: 33 loss: 0.6688 Train_acc: 0.7000
Epoch: 34 loss: 0.6688 Train_acc: 0.7000
Epoch: 35 loss: 0.6688 Train_acc: 0.7000
Epoch: 36 loss: 0.6688 Train_acc: 0.7000
Epoch: 37 loss: 0.6688 Train_acc: 0.7000
Epoch: 38 loss: 0.6688 Train_acc: 0.7000
Epoch: 39 loss: 0.6688 Train_acc: 0.7000
Epoch: 40 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 41 loss: 0.6688 Train_acc: 0.7000
Epoch: 42 loss: 0.6688 Train_acc: 0.7000
Epoch: 43 loss: 0.6688 Train_acc: 0.7000
Epoch: 44 loss: 0.6688 Train_acc: 0.7000
Epoch: 45 loss: 0.6688 Train_acc: 0.7000
Epoch: 46 loss: 0.6688 Train_acc: 0.7000
Epoch: 47 loss: 0.6688 Train_acc: 0.7000
Epoch: 48 loss: 0.6688 Train_acc: 0.7000
Epoch: 49 loss: 0.6688 Train_acc: 0.7000
Epoch: 50 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 51 loss: 0.6688 Train_acc: 0.7000
Epoch: 52 loss: 0.6688 Train_acc: 0.7000
Epoch: 53 loss: 0.6688 Train_acc: 0.7000
Epoch: 54 loss: 0.6688 Train_acc: 0.7000
Epoch: 55 loss: 0.6688 Train_acc: 0.7000
Epoch: 56 loss: 0.6688 Train_acc: 0.7000
Epoch: 57 loss: 0.6688 Train_acc: 0.7000
Epoch: 58 loss: 0.6688 Train_acc: 0.7000
Epoch: 59 loss: 0.6688 Train_acc: 0.7000
Epoch: 60 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 61 loss: 0.6688 Train_acc: 0.7000
Epoch: 62 loss: 0.6688 Train_acc: 0.7000
Epoch: 63 loss: 0.6688 Train_acc: 0.7000
Epoch: 64 loss: 0.6688 Train_acc: 0.7000
Epoch: 65 loss: 0.6688 Train_acc: 0.7000
Epoch: 66 loss: 0.6688 Train_acc: 0.7000
Epoch: 67 loss: 0.6688 Train_acc: 0.7000
Epoch: 68 loss: 0.6688 Train_acc: 0.7000
Epoch: 69 loss: 0.6688 Train_acc: 0.7000
Epoch: 70 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 71 loss: 0.6688 Train_acc: 0.7000
Epoch: 72 loss: 0.6688 Train_acc: 0.7000
Epoch: 73 loss: 0.6688 Train_acc: 0.7000
Epoch: 74 loss: 0.6688 Train_acc: 0.7000
Epoch: 75 loss: 0.6688 Train_acc: 0.7000
Epoch: 76 loss: 0.6688 Train_acc: 0.7000
Epoch: 77 loss: 0.6688 Train_acc: 0.7000
Epoch: 78 loss: 0.6688 Train_acc: 0.7000
Epoch: 79 loss: 0.6688 Train_acc: 0.7000
Epoch: 80 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 81 loss: 0.6688 Train_acc: 0.7000
Epoch: 82 loss: 0.6688 Train_acc: 0.7000
Epoch: 83 loss: 0.6688 Train_acc: 0.7000
Epoch: 84 loss: 0.6688 Train_acc: 0.7000
Epoch: 85 loss: 0.6688 Train_acc: 0.7000
Epoch: 86 loss: 0.6688 Train_acc: 0.7000
Epoch: 87 loss: 0.6688 Train_acc: 0.7000
Epoch: 88 loss: 0.6688 Train_acc: 0.7000
Epoch: 89 loss: 0.6688 Train_acc: 0.7000
Epoch: 90 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
Epoch: 91 loss: 0.6688 Train_acc: 0.7000
Epoch: 92 loss: 0.6688 Train_acc: 0.7000
Epoch: 93 loss: 0.6688 Train_acc: 0.7000
Epoch: 94 loss: 0.6688 Train_acc: 0.7000
Epoch: 95 loss: 0.6688 Train_acc: 0.7000
Epoch: 96 loss: 0.6688 Train_acc: 0.7000
Epoch: 97 loss: 0.6688 Train_acc: 0.7000
Epoch: 98 loss: 0.6688 Train_acc: 0.7000
Epoch: 99 loss: 0.6688 Train_acc: 0.7000
Epoch: 100 loss: 0.6688 Train_acc: 0.7000
Test_acc: 0.2500
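The node-level readout and loss of the training loop can be mirrored in NumPy. In the sketch below, `softmax` and `bce_with_logits` are our own reimplementations (the latter follows the formula documented for PyTorch's `binary_cross_entropy_with_logits` with mean reduction), and the triangle graph and edge outputs are hypothetical. Because the loop passes softmax probabilities rather than logits to this loss, the loss is bounded below: for a two-class target, its minimum is reached at the perfect prediction and is about 0.5032, so the logged loss cannot approach zero.

```python
import numpy as np


def softmax(z, axis=1):
    z = z - z.max(axis=axis, keepdims=True)  # numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def bce_with_logits(x, y):
    """Mean of -[y log sigmoid(x) + (1 - y) log(1 - sigmoid(x))] over all entries."""
    # -log sigmoid(x) = log(1 + e^{-x}) and -log(1 - sigmoid(x)) = log(1 + e^{x})
    return np.mean(y * np.logaddexp(0, -x) + (1 - y) * np.logaddexp(0, x))


# Node readout as in the training loop: project edge outputs with B_1, then softmax
B1 = np.array([[-1, -1, 0], [1, 0, -1], [0, 1, 1]], dtype=float)  # triangle graph
y_hat_edge = softmax(np.array([[2.0, 0.0], [0.5, 0.5], [0.0, 1.0]]))
y_hat = softmax(B1 @ y_hat_edge)
print(y_hat.shape)  # (3, 2)

# Probabilities (not logits) fed to the loss: even a perfect prediction [1, 0]
# for the target [1, 0] leaves a positive loss
perfect = bce_with_logits(np.array([[1.0, 0.0]]), np.array([[1.0, 0.0]]))
print(round(float(perfect), 4))  # 0.5032
```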