Train a Simplicial Convolutional Neural Network (SCNN)#

In this notebook, we create and train a convolutional neural network in the simplicial complex domain, as proposed in Yang et al.: Simplicial Convolutional Neural Networks (2022).

We train the model to perform:#

1.  Complex classification using the shrec16 benchmark dataset.
2.  Node classification using the karate club dataset.

Simplicial Convolutional Neural Networks [SCNN]#

At layer \(t\), given the input simplicial (edge) feature matrix \(\mathbf{H}_t\), the SCNN layer is defined as

\[\mathbf{H}_{t+1} = \sigma \Bigg[ \mathbf{H}_t\mathbf{\Theta}_t + \sum_{p_d=1}^{P_d}(\mathbf{L}_{\downarrow,1})^{p_d}\mathbf{H}_t\mathbf{\Theta}_{t,p_d} + \sum_{p_u=1}^{P_u}(\mathbf{L}_{\uparrow,1})^{p_u}\mathbf{H}_{t}\mathbf{\Theta}_{t,p_u} \Bigg]\]

where \(P_d\) and \(P_u\) are the lower and upper convolution orders, respectively, and \(\mathbf{\Theta}_{t,p_d}\) and \(\mathbf{\Theta}_{t,p_u}\) are the learnable weights. One can use \((\mathbf{L}_{\uparrow,1})^{p_u}\) and \((\mathbf{L}_{\downarrow,1})^{p_d}\) to perform higher-order upper and lower convolutions.
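To make the sums concrete, below is a minimal sketch of one such layer written as plain tensor operations. This is an illustration only, assuming dense tensors and a hypothetical function name, with a sigmoid as an example nonlinearity; the actual topomodelx layer operates on sparse Laplacians and manages its own weights.

[ ]:
import torch


def scnn_layer_sketch(h, lap_down, lap_up, theta_id, thetas_down, thetas_up):
    # h: (n_edges, C_in); theta_id and each theta: (C_in, C_out)
    out = h @ theta_id  # identity term H Theta
    term = h
    for theta in thetas_down:  # lower terms (L_down)^p H Theta_p, p = 1..P_d
        term = lap_down @ term
        out = out + term @ theta
    term = h
    for theta in thetas_up:  # upper terms (L_up)^p H Theta_p, p = 1..P_u
        term = lap_up @ term
        out = out + term @ theta
    return torch.sigmoid(out)  # sigma: an example choice of nonlinearity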

To align with the notation of Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023), we can denote the above layer with the following message-passing scheme.

🟥 \(\quad m_{y \rightarrow \{z\} \rightarrow x}^{p_u,(1 \rightarrow 2 \rightarrow 1)} = ((L_{\uparrow,1})^{p_u})_{xy} \cdot h_y^{t,(1)} \cdot \theta^{t, p_u}\) ——– Aggregate from \(p_u\)-hop upper neighbor \(y\) to \(x\)

🟥 \(\quad m_{y \rightarrow \{z\} \rightarrow x}^{p_d,(1 \rightarrow 0 \rightarrow 1)} = ((L_{\downarrow,1})^{p_d})_{xy} \cdot h_y^{t,(1)} \cdot \theta^{t, p_d}\) ——– Aggregate from \(p_d\)-hop lower neighbor \(y\) to \(x\)

🟥 \(\quad m^{(1 \rightarrow 1)}_{x \rightarrow x} = \theta^t \cdot h_x^{t, (1)}\) ——– Aggregate from \(x\) itself

🟧 \(\quad m_{x}^{p_u,(1 \rightarrow 2 \rightarrow 1)} = \sum_{y \in \mathcal{L}_\uparrow(x)}m_{y \rightarrow \{z\} \rightarrow x}^{p_u,(1 \rightarrow 2 \rightarrow 1)}\) ——– Collect the aggregated information from the upper neighborhood

🟧 \(\quad m_{x}^{p_d,(1 \rightarrow 0 \rightarrow 1)} = \sum_{y \in \mathcal{L}_\downarrow(x)}m_{y \rightarrow \{z\} \rightarrow x}^{p_d,(1 \rightarrow 0 \rightarrow 1)}\) ——– Collect the aggregated information from the lower neighborhood

🟧 \(\quad m^{(1 \rightarrow 1)}_{x} = m^{(1 \rightarrow 1)}_{x \rightarrow x}\)

🟩 \(\quad m_x^{(1)} = m_x^{(1 \rightarrow 1)} + \sum_{p_u=1}^{P_u} m_{x}^{p_u,(1 \rightarrow 2 \rightarrow 1)} + \sum_{p_d=1}^{P_d} m_{x}^{p_d,(1 \rightarrow 0 \rightarrow 1)}\) ——– Collect all the aggregated information

🟦 \(\quad h_x^{t+1, (1)} = \sigma(m_x^{(1)})\) ——– Pass through the nonlinearity

1. Complex Classification#

[1]:
import numpy as np
import toponetx as tnx
import torch
from sklearn.model_selection import train_test_split

from topomodelx.nn.simplicial.scnn import SCNN
from topomodelx.utils.sparse import from_sparse

%load_ext autoreload
%autoreload 2

Pre-processing#

We must first lift our graph dataset into the simplicial complex domain.

[2]:
shrec, _ = tnx.datasets.shrec_16(size="small")

shrec = {key: np.array(value) for key, value in shrec.items()}

x_0s = shrec["node_feat"]
x_1s = shrec["edge_feat"]
x_2s = shrec["face_feat"]

ys = shrec["label"]
simplexes = shrec["complexes"]
Loading shrec 16 small dataset...

done!
[3]:
in_channels_0 = x_0s[-1].shape[1]
in_channels_1 = x_1s[-1].shape[1]
in_channels_2 = x_2s[-1].shape[1]

Define Neighborhood Structures#

Get incidence matrices and Hodge Laplacians

[4]:
max_rank = 2  # the order of the SC is two
incidence_1_list = []
incidence_2_list = []

laplacian_0_list = []
laplacian_down_1_list = []
laplacian_up_1_list = []
laplacian_2_list = []

for simplex in simplexes:
    incidence_1 = simplex.incidence_matrix(rank=1)
    incidence_2 = simplex.incidence_matrix(rank=2)
    laplacian_0 = simplex.hodge_laplacian_matrix(rank=0)
    laplacian_down_1 = simplex.down_laplacian_matrix(rank=1)
    laplacian_up_1 = simplex.up_laplacian_matrix(rank=1)
    laplacian_2 = simplex.hodge_laplacian_matrix(rank=2)

    incidence_1 = from_sparse(incidence_1)
    incidence_2 = from_sparse(incidence_2)
    laplacian_0 = from_sparse(laplacian_0)
    laplacian_down_1 = from_sparse(laplacian_down_1)
    laplacian_up_1 = from_sparse(laplacian_up_1)
    laplacian_2 = from_sparse(laplacian_2)

    incidence_1_list.append(incidence_1)
    incidence_2_list.append(incidence_2)
    laplacian_0_list.append(laplacian_0)
    laplacian_down_1_list.append(laplacian_down_1)
    laplacian_up_1_list.append(laplacian_up_1)
    laplacian_2_list.append(laplacian_2)

Train the Neural Network#

We create the model with our pre-computed neighborhood structures and define an optimizer.

[5]:
rank = 1  # simplex level
conv_order_down = 2
conv_order_up = 2
hidden_channels = 4
out_channels = 1  # num classes
num_layers = 2

# select the simplex level
if rank == 0:
    laplacian_down = None
    laplacian_up = laplacian_0_list  # the graph laplacian
    conv_order_down = 0
    x = x_0s
    in_channels = in_channels_0
elif rank == 1:
    laplacian_down = laplacian_down_1_list
    laplacian_up = laplacian_up_1_list
    x = x_1s
    in_channels = in_channels_1
elif rank == 2:
    laplacian_down = laplacian_2_list
    laplacian_up = None
    x = x_2s
    in_channels = in_channels_2
else:
    raise ValueError("Rank must be not larger than 2 on this dataset")
[6]:
class Network(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        conv_order_down,
        conv_order_up,
        n_layers=2,
    ):
        super().__init__()
        self.base_model = SCNN(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            conv_order_down=conv_order_down,
            conv_order_up=conv_order_up,
            n_layers=n_layers,
        )
        self.linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, laplacian_down, laplacian_up):
        x = self.base_model(x, laplacian_down, laplacian_up)
        x = self.linear(x)
        # Mean-pool the simplex-level outputs into one complex-level prediction
        one_dimensional_cells_mean = torch.nanmean(x, dim=0)
        one_dimensional_cells_mean[torch.isnan(one_dimensional_cells_mean)] = 0
        return one_dimensional_cells_mean
[7]:
model = Network(
    in_channels=in_channels,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    conv_order_down=conv_order_down,
    conv_order_up=conv_order_up,
    n_layers=num_layers,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.MSELoss()
[8]:
test_size = 0.2
x_train, x_test = train_test_split(x, test_size=test_size, shuffle=False)

laplacian_down_train, laplacian_down_test = train_test_split(
    laplacian_down, test_size=test_size, shuffle=False
)
laplacian_up_train, laplacian_up_test = train_test_split(
    laplacian_up, test_size=test_size, shuffle=False
)
y_train, y_test = train_test_split(ys, test_size=test_size, shuffle=False)
[9]:
test_interval = 2
num_epochs = 10

# select which feature to use for labeling
simplex_order_select = 1

for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    for x, laplacian_down, laplacian_up, y in zip(
        x_train, laplacian_down_train, laplacian_up_train, y_train, strict=False
    ):
        x = torch.tensor(x, dtype=torch.float)
        y = torch.tensor(y, dtype=torch.float)
        optimizer.zero_grad()

        y_hat = model(x, laplacian_down, laplacian_up)

        loss = loss_fn(y_hat, y)

        epoch_loss.append(loss.item())
        loss.backward()
        optimizer.step()

    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            for x, laplacian_down, laplacian_up, y in zip(
                x_test, laplacian_down_test, laplacian_up_test, y_test, strict=False
            ):
                x = torch.tensor(x, dtype=torch.float)
                y = torch.tensor(y, dtype=torch.float)

                y_hat = model(x, laplacian_down, laplacian_up)

                loss = loss_fn(y_hat, y)
            print(f"Test_loss: {loss:.4f}", flush=True)
/Users/gbg141/Documents/TopoProjectX/TopoModelX/venv_modelx/lib/python3.11/site-packages/torch/nn/modules/loss.py:536: UserWarning: Using a target size (torch.Size([])) that is different to the input size (torch.Size([1])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
  return F.mse_loss(input, target, reduction=self.reduction)
Epoch: 1 loss: 888.6446
Epoch: 2 loss: 149.8124
Test_loss: 16.6203
Epoch: 3 loss: 158.0982
Epoch: 4 loss: 178.9690
Test_loss: 64.3029
Epoch: 5 loss: 172.9724
Epoch: 6 loss: 87.4266
Test_loss: 110.1082
Epoch: 7 loss: 91.9618
Epoch: 8 loss: 89.1734
Test_loss: 117.3142
Epoch: 9 loss: 87.6673
Epoch: 10 loss: 86.8797
Test_loss: 116.6862
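Note the UserWarning emitted during training: y_hat has shape [1] while the target y is a 0-dimensional tensor, so the MSE loss silently broadcasts. A minimal fix, not applied in the run above, is to give the label the same shape inside the loop:

[ ]:
y = torch.tensor([y], dtype=torch.float)  # shape [1], matching y_hat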

2. Node Classification#

Import Karate dataset#

[11]:
dataset = tnx.datasets.karate_club(complex_type="simplicial")
print(dataset)

# Maximal simplex order
max_rank = dataset.dim
print("maximal simple order:", max_rank)
Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4
maximal simple order: 4

Define Neighborhood Structures#

Get incidence matrices and Hodge Laplacians

[12]:
incidence_1 = dataset.incidence_matrix(rank=1)
incidence_1 = from_sparse(incidence_1)
incidence_2 = dataset.incidence_matrix(rank=2)
incidence_2 = from_sparse(incidence_2)
print(f"The incidence matrix B1 has shape: {incidence_1.shape}.")
print(f"The incidence matrix B2 has shape: {incidence_2.shape}.")
The incidence matrix B1 has shape: torch.Size([34, 78]).
The incidence matrix B2 has shape: torch.Size([78, 45]).

Weighted Hodge Laplacians#

In the original paper, the weighted versions of the Hodge Laplacians are used. However, the current TopoNetX package does not provide this weighting feature yet.
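As a rough illustration of what such a weighting can look like, one common choice is a symmetric degree-normalization of a Laplacian. Below is a minimal sketch on a dense tensor; symmetric_normalize is a hypothetical helper, not part of TopoNetX.

[ ]:
def symmetric_normalize(laplacian, eps=1e-6):
    """Rescale a Laplacian as D^{-1/2} L D^{-1/2}, where D is its diagonal."""
    d = torch.clamp(torch.diagonal(laplacian), min=eps)
    d_inv_sqrt = torch.diag(d.pow(-0.5))
    return d_inv_sqrt @ laplacian @ d_inv_sqrt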

[13]:
laplacian_0 = dataset.hodge_laplacian_matrix(rank=0)
laplacian_down_1 = dataset.down_laplacian_matrix(rank=1)
laplacian_up_1 = dataset.up_laplacian_matrix(rank=1)
laplacian_down_2 = dataset.down_laplacian_matrix(rank=2)
laplacian_up_2 = dataset.up_laplacian_matrix(rank=2)

laplacian_0 = from_sparse(laplacian_0)
laplacian_down_1 = from_sparse(laplacian_down_1)
laplacian_up_1 = from_sparse(laplacian_up_1)
laplacian_down_2 = from_sparse(laplacian_down_2)
laplacian_up_2 = from_sparse(laplacian_up_2)
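As a quick sanity check, both 1-Laplacians act on the 78 edges of the complex:

[ ]:
print(laplacian_down_1.shape)  # torch.Size([78, 78])
print(laplacian_up_1.shape)  # torch.Size([78, 78])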

Import signals#

For example, to perform learning on the edges, we use the input features on the edges, \(\mathbf{x}_1\).

[14]:
x_0 = list(dataset.get_simplex_attributes("node_feat").values())
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
x_1 = list(dataset.get_simplex_attributes("edge_feat").values())
x_1 = np.stack(x_1)
channels_edges = x_1.shape[-1]
x_2 = list(dataset.get_simplex_attributes("face_feat").values())
x_2 = np.stack(x_2)
channels_faces = x_2.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")
There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.

Define a function to select the features on a certain order of simplices.

[15]:
"""A function to obtain features based on the input: rank
"""


def get_simplicial_features(dataset, rank):
    if rank == 0:
        which_feat = "node_feat"
    elif rank == 1:
        which_feat = "edge_feat"
    elif rank == 2:
        which_feat = "face_feat"
    else:
        raise ValueError(
            "input dimension must be 0, 1 or 2, because features are supported on nodes, edges and faces"
        )

    x = list(dataset.get_simplex_attributes(which_feat).values())
    return torch.tensor(np.stack(x))
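For example, selecting the edge features:

[ ]:
edge_feat = get_simplicial_features(dataset, rank=1)
print(edge_feat.shape)  # torch.Size([78, 2])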

Define binary labels and prepare the train-test split#

[16]:
y = np.array(
    [1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1,
     1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
)
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_train = y_true[:30]
y_test = y_true[-4:]

y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)

Create the SCNN for node classification#

Using the SCNNLayer class, we create a neural network with stacked layers, without aggregation.

[17]:
class Network(torch.nn.Module):
    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        conv_order_down,
        conv_order_up,
        n_layers=2,
    ):
        super().__init__()
        self.base_model = SCNN(
            in_channels=in_channels,
            hidden_channels=hidden_channels,
            conv_order_down=conv_order_down,
            conv_order_up=conv_order_up,
            n_layers=n_layers,
        )
        self.linear = torch.nn.Linear(hidden_channels, out_channels)

    def forward(self, x, laplacian_down, laplacian_up):
        x = self.base_model(x, laplacian_down, laplacian_up)
        return self.linear(x)
[18]:
"""
Select the simplex order, i.e., on which level of simplices the learning will be performed
"""
rank = 1  # simplex level
conv_order_down = 2
conv_order_up = 2
x = get_simplicial_features(dataset, rank)
channels_x = x.shape[-1]
if rank == 0:
    laplacian_down = None
    laplacian_up = laplacian_0  # the graph laplacian
    conv_order_down = 0
elif rank == 1:
    laplacian_down = laplacian_down_1
    laplacian_up = laplacian_up_1
elif rank == 2:
    laplacian_down = laplacian_down_2
    laplacian_up = laplacian_up_2
else:
    raise ValueError("Rank must be not larger than 2 on this dataset")

hidden_channels = 16
out_channels = 2  # num classes
num_layers = 1

model = Network(
    in_channels=channels_x,
    hidden_channels=hidden_channels,
    out_channels=out_channels,
    conv_order_down=conv_order_down,
    conv_order_up=conv_order_up,
    n_layers=num_layers,
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
print(model)
Network(
  (base_model): SCNN(
    (layers): ModuleList(
      (0): SCNNLayer()
    )
  )
  (linear): Linear(in_features=16, out_features=2, bias=True)
)

The final linear layer produces a two-channel output on each edge; projecting it through the incidence matrix \(\mathbf{B}_1\) then yields an output of shape \(n_{\rm{nodes}} \times 2\), which we can compare with the binary node labels.
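Concretely, the model emits a two-channel signal on the 78 edges, and multiplying by the incidence matrix \(\mathbf{B}_1\) (of shape \(34 \times 78\)) pushes it down to the 34 nodes. A sketch of the shape flow:

[ ]:
y_hat_edge = model(x, laplacian_down, laplacian_up)  # torch.Size([78, 2])
y_hat_node = torch.sparse.mm(incidence_1, y_hat_edge)  # torch.Size([34, 2])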

Train the SCNN#

The following cell performs the training, running the network for a small number of epochs.

[19]:
test_interval = 10
num_epochs = 100
for epoch_i in range(1, num_epochs + 1):
    epoch_loss = []
    model.train()
    optimizer.zero_grad()

    y_hat_edge = model(x, laplacian_down, laplacian_up)
    # We project the edge-level output of the model to the node-level
    # and apply softmax fn to get the final node-level classification output
    y_hat = torch.softmax(torch.sparse.mm(incidence_1, y_hat_edge), dim=1)
    loss = torch.nn.functional.binary_cross_entropy_with_logits(
        y_hat[: len(y_train)].float(), y_train.float()
    )
    epoch_loss.append(loss.item())
    loss.backward()
    optimizer.step()

    y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
    accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
    print(
        f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
        flush=True,
    )
    if epoch_i % test_interval == 0:
        with torch.no_grad():
            y_hat_edge_test = model(x, laplacian_down, laplacian_up)
            # Projection to node-level
            y_hat_test = torch.softmax(
                torch.sparse.mm(incidence_1, y_hat_edge_test), dim=1
            )
            y_pred_test = torch.where(
                y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
            )
            test_accuracy = (
                torch.eq(y_pred_test[-len(y_test) :], y_test)
                .all(dim=1)
                .float()
                .mean()
                .item()
            )
            print(f"Test_acc: {test_accuracy:.4f}", flush=True)
Epoch: 1 loss: 0.7327 Train_acc: 0.3000
Epoch: 2 loss: 0.7171 Train_acc: 0.7333
Epoch: 3 loss: 0.6984 Train_acc: 0.7333
Epoch: 4 loss: 0.6773 Train_acc: 0.7333
Epoch: 5 loss: 0.6590 Train_acc: 0.7333
Epoch: 6 loss: 0.6431 Train_acc: 0.7667
Epoch: 7 loss: 0.6288 Train_acc: 0.7667
Epoch: 8 loss: 0.6184 Train_acc: 0.8000
Epoch: 9 loss: 0.6100 Train_acc: 0.8000
Epoch: 10 loss: 0.6023 Train_acc: 0.8333
Test_acc: 0.5000
Epoch: 11 loss: 0.5951 Train_acc: 0.8333
Epoch: 12 loss: 0.5880 Train_acc: 0.8333
Epoch: 13 loss: 0.5796 Train_acc: 0.8333
Epoch: 14 loss: 0.5721 Train_acc: 0.8667
Epoch: 15 loss: 0.5693 Train_acc: 0.9000
Epoch: 16 loss: 0.5686 Train_acc: 0.9000
Epoch: 17 loss: 0.5679 Train_acc: 0.9000
Epoch: 18 loss: 0.5669 Train_acc: 0.9000
Epoch: 19 loss: 0.5655 Train_acc: 0.8667
Epoch: 20 loss: 0.5639 Train_acc: 0.9000
Test_acc: 0.7500
Epoch: 21 loss: 0.5622 Train_acc: 0.9000
Epoch: 22 loss: 0.5605 Train_acc: 0.9000
Epoch: 23 loss: 0.5592 Train_acc: 0.9000
Epoch: 24 loss: 0.5590 Train_acc: 0.9000
Epoch: 25 loss: 0.5603 Train_acc: 0.9000
Epoch: 26 loss: 0.5607 Train_acc: 0.9000
Epoch: 27 loss: 0.5594 Train_acc: 0.9000
Epoch: 28 loss: 0.5582 Train_acc: 0.9000
Epoch: 29 loss: 0.5579 Train_acc: 0.9000
Epoch: 30 loss: 0.5582 Train_acc: 0.9000
Test_acc: 0.7500
Epoch: 31 loss: 0.5584 Train_acc: 0.9000
Epoch: 32 loss: 0.5584 Train_acc: 0.9000
Epoch: 33 loss: 0.5581 Train_acc: 0.9000
Epoch: 34 loss: 0.5576 Train_acc: 0.9000
Epoch: 35 loss: 0.5570 Train_acc: 0.9000
Epoch: 36 loss: 0.5565 Train_acc: 0.9000
Epoch: 37 loss: 0.5563 Train_acc: 0.9000
Epoch: 38 loss: 0.5564 Train_acc: 0.9000
Epoch: 39 loss: 0.5564 Train_acc: 0.9000
Epoch: 40 loss: 0.5562 Train_acc: 0.9000
Test_acc: 0.7500
Epoch: 41 loss: 0.5558 Train_acc: 0.9000
Epoch: 42 loss: 0.5555 Train_acc: 0.9000
Epoch: 43 loss: 0.5554 Train_acc: 0.9000
Epoch: 44 loss: 0.5553 Train_acc: 0.9000
Epoch: 45 loss: 0.5553 Train_acc: 0.9000
Epoch: 46 loss: 0.5553 Train_acc: 0.9000
Epoch: 47 loss: 0.5552 Train_acc: 0.9000
Epoch: 48 loss: 0.5551 Train_acc: 0.9000
Epoch: 49 loss: 0.5549 Train_acc: 0.9000
Epoch: 50 loss: 0.5548 Train_acc: 0.9000
Test_acc: 0.7500
Epoch: 51 loss: 0.5547 Train_acc: 0.9000
Epoch: 52 loss: 0.5547 Train_acc: 0.9000
Epoch: 53 loss: 0.5546 Train_acc: 0.9000
Epoch: 54 loss: 0.5546 Train_acc: 0.9000
Epoch: 55 loss: 0.5546 Train_acc: 0.9000
Epoch: 56 loss: 0.5545 Train_acc: 0.9000
Epoch: 57 loss: 0.5545 Train_acc: 0.9000
Epoch: 58 loss: 0.5544 Train_acc: 0.9000
Epoch: 59 loss: 0.5544 Train_acc: 0.9000
Epoch: 60 loss: 0.5544 Train_acc: 0.9000
Test_acc: 0.7500
Epoch: 61 loss: 0.5543 Train_acc: 0.9000
Epoch: 62 loss: 0.5543 Train_acc: 0.9000
Epoch: 63 loss: 0.5543 Train_acc: 0.9000
Epoch: 64 loss: 0.5543 Train_acc: 0.9000
Epoch: 65 loss: 0.5543 Train_acc: 0.9000
Epoch: 66 loss: 0.5542 Train_acc: 0.9000
Epoch: 67 loss: 0.5542 Train_acc: 0.9000
Epoch: 68 loss: 0.5542 Train_acc: 0.9000
Epoch: 69 loss: 0.5542 Train_acc: 0.9000
Epoch: 70 loss: 0.5541 Train_acc: 0.9000
Test_acc: 0.7500
Epoch: 71 loss: 0.5541 Train_acc: 0.9000
Epoch: 72 loss: 0.5541 Train_acc: 0.9000
Epoch: 73 loss: 0.5541 Train_acc: 0.9000
Epoch: 74 loss: 0.5541 Train_acc: 0.9000
Epoch: 75 loss: 0.5541 Train_acc: 0.9000
Epoch: 76 loss: 0.5541 Train_acc: 0.9000
Epoch: 77 loss: 0.5540 Train_acc: 0.9000
Epoch: 78 loss: 0.5540 Train_acc: 0.9000
Epoch: 79 loss: 0.5540 Train_acc: 0.9000
Epoch: 80 loss: 0.5540 Train_acc: 0.9000
Test_acc: 0.7500
Epoch: 81 loss: 0.5540 Train_acc: 0.9000
Epoch: 82 loss: 0.5540 Train_acc: 0.9000
Epoch: 83 loss: 0.5540 Train_acc: 0.9000
Epoch: 84 loss: 0.5539 Train_acc: 0.9000
Epoch: 85 loss: 0.5539 Train_acc: 0.9000
Epoch: 86 loss: 0.5539 Train_acc: 0.9000
Epoch: 87 loss: 0.5539 Train_acc: 0.9000
Epoch: 88 loss: 0.5539 Train_acc: 0.9000
Epoch: 89 loss: 0.5539 Train_acc: 0.9000
Epoch: 90 loss: 0.5539 Train_acc: 0.9000
Test_acc: 0.7500
Epoch: 91 loss: 0.5539 Train_acc: 0.9000
Epoch: 92 loss: 0.5539 Train_acc: 0.9000
Epoch: 93 loss: 0.5538 Train_acc: 0.9000
Epoch: 94 loss: 0.5538 Train_acc: 0.9000
Epoch: 95 loss: 0.5538 Train_acc: 0.9000
Epoch: 96 loss: 0.5538 Train_acc: 0.9000
Epoch: 97 loss: 0.5538 Train_acc: 0.9000
Epoch: 98 loss: 0.5538 Train_acc: 0.9000
Epoch: 99 loss: 0.5538 Train_acc: 0.9000
Epoch: 100 loss: 0.5538 Train_acc: 0.9000
Test_acc: 0.7500
[ ]: