Train a Simplicial 2-complex convolutional neural network (SCConv)#
In this notebook, we will create and train a Simplicial 2-complex convolutional neural network in the simplicial complex domain, as proposed in the paper by Bunch et al.: Simplicial 2-Complex Convolutional Neural Networks (2020).
We train the model to perform binary node classification using the Karate Club dataset.
The equations of one layer of this neural network are given by:
🟥 \(\quad m_{y\rightarrow x}^{(0\rightarrow 0)} = ({\tilde{A}_{\uparrow,0}})_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0\rightarrow0)}\)
🟥 \(\quad m^{(1\rightarrow0)}_{y\rightarrow x} = (B_1)_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1\rightarrow 0)}\)
🟥 \(\quad m^{(0 \rightarrow 1)}_{y \rightarrow x} = (\tilde B_1)_{xy} \cdot h_y^{t,(0)} \cdot \Theta^{t,(0 \rightarrow1)}\)
🟥 \(\quad m^{(1\rightarrow1)}_{y\rightarrow x} = ({\tilde{A}_{\downarrow,1}} + {\tilde{A}_{\uparrow,1}})_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1\rightarrow1)}\)
🟥 \(\quad m^{(2\rightarrow1)}_{y \rightarrow x} = (B_2)_{xy} \cdot h_y^{t,(2)} \cdot \Theta^{t,(2 \rightarrow1)}\)
🟥 \(\quad m^{(1 \rightarrow 2)}_{y \rightarrow x} = (\tilde B_2)_{xy} \cdot h_y^{t,(1)} \cdot \Theta^{t,(1 \rightarrow 2)}\)
🟥 \(\quad m^{(2 \rightarrow 2)}_{y \rightarrow x} = ({\tilde{A}_{\downarrow,2}})_{xy} \cdot h_y^{t,(2)} \cdot \Theta^{t,(2 \rightarrow 2)}\)
🟧 \(\quad m_x^{(0 \rightarrow 0)} = \sum_{y \in \mathcal{L}_\uparrow(x)} m_{y \rightarrow x}^{(0 \rightarrow 0)}\)
🟧 \(\quad m_x^{(1 \rightarrow 0)} = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(1 \rightarrow 0)}\)
🟧 \(\quad m_x^{(0 \rightarrow 1)} = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(0 \rightarrow 1)}\)
🟧 \(\quad m_x^{(1 \rightarrow 1)} = \sum_{y \in (\mathcal{L}_\uparrow(x) + \mathcal{L}_\downarrow(x))} m_{y \rightarrow x}^{(1 \rightarrow 1)}\)
🟧 \(\quad m_x^{(2 \rightarrow 1)} = \sum_{y \in \mathcal{C}(x)} m_{y \rightarrow x}^{(2 \rightarrow 1)}\)
🟧 \(\quad m_x^{(1 \rightarrow 2)} = \sum_{y \in \mathcal{B}(x)} m_{y \rightarrow x}^{(1 \rightarrow 2)}\)
🟧 \(\quad m_x^{(2 \rightarrow 2)} = \sum_{y \in \mathcal{L}_\downarrow(x)} m_{y \rightarrow x}^{(2 \rightarrow 2)}\)
🟩 \(\quad m_x^{(0)} = m_x^{(1\rightarrow0)}+ m_x^{(0\rightarrow0)}\)
🟩 \(\quad m_x^{(1)} = m_x^{(0\rightarrow1)}+ m_x^{(2\rightarrow1)}+ m_x^{(1\rightarrow1)}\)
🟩 \(\quad m_x^{(2)} = m_x^{(1\rightarrow2)}+ m_x^{(2\rightarrow2)}\)
🟦 \(\quad h^{t+1, (0)}_x = \sigma(m_x^{(0)})\)
🟦 \(\quad h^{t+1, (1)}_x = \sigma(m_x^{(1)})\)
🟦 \(\quad h^{t+1, (2)}_x = \sigma(m_x^{(2)})\)
The notation is defined in Papillon et al.: Architectures of Topological Deep Learning: A Survey of Topological Neural Networks (2023).
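To see how the equations fit together in matrix form, here is a minimal NumPy sketch of one layer on a toy complex made of a single filled triangle. It is not the SCConv implementation: the operators are left unnormalised, \(\sigma\) is taken to be a ReLU, the weights are random, and the between-neighbourhood step sums every message type that targets a given rank. All of these choices are assumptions of the sketch.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy complex: nodes 0, 1, 2; edges (0,1), (0,2), (1,2); one face (0,1,2).
B1 = np.array([[-1, -1, 0], [1, 0, -1], [0, 1, 1]], dtype=float)  # nodes x edges
B2 = np.array([[1], [-1], [1]], dtype=float)                      # edges x faces

# Unnormalised stand-ins for the adjacency operators (assumption of this sketch).
A_up0 = B1 @ B1.T
A_down1 = B1.T @ B1
A_up1 = B2 @ B2.T
A_down2 = B2.T @ B2

channels = 2
h0 = rng.normal(size=(3, channels))  # node features
h1 = rng.normal(size=(3, channels))  # edge features
h2 = rng.normal(size=(1, channels))  # face features


def theta():
    """Draw one random weight matrix (stands in for a learned Theta)."""
    return rng.normal(size=(channels, channels))


def sigma(z):
    """ReLU, used here as a placeholder nonlinearity."""
    return np.maximum(z, 0.0)


# Message and within-neighbourhood aggregation collapse into matrix products;
# the sums below are the between-neighbourhood aggregation.
m0 = A_up0 @ h0 @ theta() + B1 @ h1 @ theta()
m1 = B1.T @ h0 @ theta() + (A_down1 + A_up1) @ h1 @ theta() + B2 @ h2 @ theta()
m2 = B2.T @ h1 @ theta() + A_down2 @ h2 @ theta()

# Update step.
h0_next, h1_next, h2_next = sigma(m0), sigma(m1), sigma(m2)
print(h0_next.shape, h1_next.shape, h2_next.shape)
```

Each rank keeps its own number of rows (nodes, edges, faces), while the channel dimension is set by the weight matrices.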
import numpy as np
import toponetx as tnx
import torch
from scipy.sparse import coo_matrix, diags
from topomodelx.nn.simplicial.scconv import SCConv
from topomodelx.utils.sparse import from_sparse
%load_ext autoreload
%autoreload 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
Import dataset#
The first step is to import the Karate Club dataset. This is a single graph with 34 nodes that belong to two different social groups. We will use these groups for the task of node-level binary classification.
We must first lift our graph dataset into the simplicial complex domain.
dataset = tnx.datasets.karate_club(complex_type="simplicial")
Simplicial Complex with shape (34, 78, 45, 11, 2) and dimension 4
(34, 78, 45, 11, 2)
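The shape printed above counts the simplices of each rank. As a rough illustration of what lifting a graph can mean, the sketch below performs a clique lift on a small hypothetical graph: every set of mutually adjacent nodes becomes a simplex. The function name `clique_lift` and the toy edge list are made up for this example, and we only assume, not confirm, that the dataset loader does something comparable.

```python
from itertools import combinations


def clique_lift(edges, max_dim=2):
    """Return {dim: list of simplices}, where each simplex is a clique of the graph."""
    adj = {}
    for u, v in edges:
        adj.setdefault(u, set()).add(v)
        adj.setdefault(v, set()).add(u)
    nodes = sorted(adj)
    simplices = {0: [(n,) for n in nodes]}
    for dim in range(1, max_dim + 1):
        # Brute-force check: keep node subsets whose members are pairwise adjacent.
        simplices[dim] = [
            c
            for c in combinations(nodes, dim + 1)
            if all(b in adj[a] for a, b in combinations(c, 2))
        ]
    return simplices


# Hypothetical toy graph: a triangle 0-1-2 with a pendant node 3.
toy_edges = [(0, 1), (0, 2), (1, 2), (2, 3)]
lifted = clique_lift(toy_edges)
shape = tuple(len(lifted[d]) for d in sorted(lifted))
print(shape)
```

The brute-force enumeration is only practical for small graphs; it is meant to show the idea, not to scale.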
Define Neighbourhood Structures#
We create the neighbourhood structures expected by SCConv. The SCConv layer expects the following neighbourhood structures:
incidence_1 \(B_1\)
incidence_1_norm \(\tilde{B}_1\)
incidence_2 \(B_2\)
incidence_2_norm \(\tilde{B}_2\)
adjacency_up_0_norm \(\tilde{A}_{\uparrow,0}\)
adjacency_up_1_norm \(\tilde{A}_{\uparrow,1}\)
adjacency_down_1_norm \(\tilde{A}_{\downarrow,1}\)
adjacency_down_2_norm \(\tilde{A}_{\downarrow,2}\)
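To make these operators concrete, the sketch below writes out the incidence matrices of a single filled triangle by hand and derives the up and down Laplacians from them with NumPy. The orientation convention (edges point from the lower to the higher node index) and the variable names are assumptions of this sketch, and no normalisation is applied.

```python
import numpy as np

# Filled triangle: nodes 0, 1, 2; edges (0,1), (0,2), (1,2); face (0,1,2).
# B1[i, e] is -1 at the tail and +1 at the head of edge e.
B1 = np.array([[-1, -1, 0],
               [1, 0, -1],
               [0, 1, 1]], dtype=float)
# Boundary of the face (0,1,2) is (1,2) - (0,2) + (0,1).
B2 = np.array([[1], [-1], [1]], dtype=float)

L_up_0 = B1 @ B1.T      # nodes x nodes, the usual graph Laplacian
L_down_1 = B1.T @ B1    # edges x edges, via shared nodes
L_up_1 = B2 @ B2.T      # edges x edges, via shared faces
L_down_2 = B2.T @ B2    # faces x faces, via shared edges

# The boundary of a boundary is zero.
boundary_check = B1 @ B2
print(L_up_0)
```

Here `L_up_0` has the node degrees on the diagonal and -1 for each pair of adjacent nodes.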
# Not working, it needs to be reviewed
def normalize_higher_order_adj(A_opt):
    """Symmetrically normalize a higher-order adjacency operator.

    A_opt is an operator that maps a j-cochain to a k-cochain,
    with shape [num_of_k_simplices, num_of_j_simplices].

    Returns D^{-0.5} * (A_opt) * D^{-0.5}.
    """
    rowsum = np.array(np.abs(A_opt).sum(1))
    r_inv_sqrt = np.power(rowsum, -0.5).flatten()
    r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.0
    r_mat_inv_sqrt = diags(r_inv_sqrt)
    # Completed from the docstring formula; assumes A_opt is square.
    A_opt_to = r_mat_inv_sqrt @ A_opt @ r_mat_inv_sqrt
    return coo_matrix(A_opt_to)
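As a small, dense illustration of the \(D^{-0.5} A D^{-0.5}\) normalisation described in the docstring, the sketch below applies it to a hypothetical 4-by-4 adjacency matrix with one isolated node. The helper name `sym_normalize` and the example matrix are made up for this sketch, and it assumes a square operator.

```python
import numpy as np


def sym_normalize(A):
    """Return D^{-1/2} A D^{-1/2}, where D holds the absolute row sums of A."""
    A = np.asarray(A, dtype=float)
    rowsum = np.abs(A).sum(axis=1)
    # Rows that sum to zero would give inf; silence the warning and zero them out.
    with np.errstate(divide="ignore"):
        d_inv_sqrt = np.power(rowsum, -0.5)
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.0
    D = np.diag(d_inv_sqrt)
    return D @ A @ D


# Hypothetical example: node 0 is linked to nodes 1 and 2; node 3 is isolated.
A = np.array([[0, 1, 1, 0],
              [1, 0, 0, 0],
              [1, 0, 0, 0],
              [0, 0, 0, 0]], dtype=float)
A_norm = sym_normalize(A)
print(A_norm)
```

The isolated node keeps an all-zero row instead of producing infinities, which is what the `isinf` guard is for.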
incidence_1_norm = incidence_0_1 = from_sparse(dataset.incidence_matrix(1))
incidence_1 = incidence_1_0 = from_sparse(dataset.coincidence_matrix(1))
incidence_2_norm = incidence_1_2 = from_sparse(dataset.incidence_matrix(2))
incidence_2 = incidence_2_1 = from_sparse(dataset.coincidence_matrix(2))
adjacency_up_0_norm = adjacency_0 = from_sparse(dataset.up_laplacian_matrix(0))
adjacency_up_1_norm = adjacency_1_up = from_sparse(dataset.up_laplacian_matrix(1))
adjacency_down_1_norm = adjacency_1_down = from_sparse(dataset.down_laplacian_matrix(1))
adjacency_down_2_norm = adjacency_2 = from_sparse(dataset.down_laplacian_matrix(2))
Import signal#
We retrieve an input signal on the nodes, edges and faces. The signal will have shape \(n_\text{simplices} \times\) in_channels, where in_channels is the dimension of each simplex's feature. Here, we have in_channels = channels_nodes \(= 2\).
x_0 = list(dataset.get_simplex_attributes("node_feat").values())
x_0 = torch.tensor(np.stack(x_0))
channels_nodes = x_0.shape[-1]
print(f"There are {x_0.shape[0]} nodes with features of dimension {x_0.shape[1]}.")
x_1 = list(dataset.get_simplex_attributes("edge_feat").values())
x_1 = torch.tensor(np.stack(x_1))
print(f"There are {x_1.shape[0]} edges with features of dimension {x_1.shape[1]}.")
x_2 = list(dataset.get_simplex_attributes("face_feat").values())
x_2 = torch.tensor(np.stack(x_2))
print(f"There are {x_2.shape[0]} faces with features of dimension {x_2.shape[1]}.")
There are 34 nodes with features of dimension 2.
There are 78 edges with features of dimension 2.
There are 45 faces with features of dimension 2.
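The stacking step can be seen in isolation on a tiny example. The sketch below assumes the attributes come back as a dictionary that maps each simplex to a feature vector; the dictionary contents are hypothetical.

```python
import numpy as np

# Hypothetical attribute dictionary: simplex -> feature vector.
node_feat = {0: [1.0, 0.0], 1: [0.5, 0.5], 2: [0.0, 1.0]}

# One row per simplex, one column per feature channel.
x = np.stack(list(node_feat.values()))
n_simplices, in_channels = x.shape
print(n_simplices, in_channels)
```

The row order follows the dictionary's insertion order, so it has to match the row order of the neighbourhood matrices.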
We also pre-define the number of output channels of the model, in this case the number of node classes.
in_channels = x_0.shape[-1]
out_channels = 2
Define binary labels#
We retrieve the labels associated with the nodes of the input complex. In the Karate Club dataset, two social groups emerge, so we assign each node a binary label indicating which group it belongs to.
We one-hot encode the binary labels and hold out the last four nodes for testing.
y = np.array(
y_true = np.zeros((34, 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y
y_train = y_true[:30]
y_test = y_true[-4:]
y_train = torch.from_numpy(y_train)
y_test = torch.from_numpy(y_test)
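The same encoding and split can be checked on a small example. The label vector below is hypothetical (six nodes, not the Karate Club labels), and the last two rows are held out in place of the last four.

```python
import numpy as np

# Hypothetical binary labels for six nodes.
y = np.array([1, 1, 0, 1, 0, 0])

# One-hot encode: column 0 holds y, column 1 holds its complement.
y_true = np.zeros((len(y), 2))
y_true[:, 0] = y
y_true[:, 1] = 1 - y

# Hold out the last n_test rows for testing.
n_test = 2
y_train, y_test = y_true[:-n_test], y_true[-n_test:]
print(y_train.shape, y_test.shape)
```

Every row has exactly one nonzero entry, so the two columns can be compared directly with a two-channel model output.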
Create the Neural Network#
Using the SCConv class, we create our neural network with stacked layers. Given the considered dataset and task (Karate Club, node classification), a linear layer at the end produces an output with shape \(n_\text{nodes} \times 2\), so we can compare with our binary labels.
class Network(torch.nn.Module):
def __init__(self, in_channels, out_channels, n_layers=1):
self.base_model = SCConv(
self.linear_x0 = torch.nn.Linear(in_channels, out_channels)
self.linear_x1 = torch.nn.Linear(in_channels, out_channels)
self.linear_x2 = torch.nn.Linear(in_channels, out_channels)
def forward(
x_0, x_1, x_2 = self.base_model(
x_0 = self.linear_x0(x_0)
x_1 = self.linear_x1(x_1)
x_2 = self.linear_x2(x_2)
return torch.softmax(x_0, dim=1)
n_layers = 1
model = Network(
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
Train the Neural Network#
The following cell performs the training, looping over the network for 200 epochs and reporting the test accuracy every 10 epochs.
test_interval = 10
num_epochs = 200
for epoch_i in range(1, num_epochs + 1):
epoch_loss = []
y_hat = model(
loss = torch.nn.functional.binary_cross_entropy_with_logits(
y_hat[: len(y_train)].float(), y_train.float()
y_pred = torch.where(y_hat > 0.5, torch.tensor(1), torch.tensor(0))
accuracy = (y_pred[: len(y_train)] == y_train).all(dim=1).float().mean().item()
f"Epoch: {epoch_i} loss: {np.mean(epoch_loss):.4f} Train_acc: {accuracy:.4f}",
if epoch_i % test_interval == 0:
with torch.no_grad():
y_hat_test = model(
y_pred_test = torch.where(
y_hat_test > 0.5, torch.tensor(1), torch.tensor(0)
test_accuracy = (
torch.eq(y_pred_test[-len(y_test) :], y_test)
print(f"Test_acc: {test_accuracy:.4f}", flush=True)
Epoch: 1 loss: 0.7134 Train_acc: 0.5667
Epoch: 2 loss: 0.7133 Train_acc: 0.5667
Epoch: 3 loss: 0.7132 Train_acc: 0.5667
Epoch: 4 loss: 0.7131 Train_acc: 0.5667
Epoch: 5 loss: 0.7128 Train_acc: 0.5667
Epoch: 6 loss: 0.7125 Train_acc: 0.5667
Epoch: 7 loss: 0.7120 Train_acc: 0.5667
Epoch: 8 loss: 0.7117 Train_acc: 0.5667
Epoch: 9 loss: 0.7113 Train_acc: 0.5667
Epoch: 10 loss: 0.7109 Train_acc: 0.5667
Test_acc: 0.0000
Epoch: 11 loss: 0.7105 Train_acc: 0.5667
Epoch: 12 loss: 0.7099 Train_acc: 0.5667
Epoch: 13 loss: 0.7093 Train_acc: 0.5667
Epoch: 14 loss: 0.7086 Train_acc: 0.5667
Epoch: 15 loss: 0.7079 Train_acc: 0.5667
Epoch: 16 loss: 0.7070 Train_acc: 0.5667
Epoch: 17 loss: 0.7062 Train_acc: 0.5667
Epoch: 18 loss: 0.7052 Train_acc: 0.5667
Epoch: 19 loss: 0.7042 Train_acc: 0.5667
Epoch: 20 loss: 0.7030 Train_acc: 0.5667
Test_acc: 0.0000
Epoch: 21 loss: 0.7017 Train_acc: 0.5667
Epoch: 22 loss: 0.7003 Train_acc: 0.5667
Epoch: 23 loss: 0.6988 Train_acc: 0.5667
Epoch: 24 loss: 0.6971 Train_acc: 0.5667
Epoch: 25 loss: 0.6954 Train_acc: 0.5667
Epoch: 26 loss: 0.6935 Train_acc: 0.5667
Epoch: 27 loss: 0.6916 Train_acc: 0.5667
Epoch: 28 loss: 0.6894 Train_acc: 0.5667
Epoch: 29 loss: 0.6872 Train_acc: 0.5667
Epoch: 30 loss: 0.6849 Train_acc: 0.5667
Test_acc: 0.5000
Epoch: 31 loss: 0.6824 Train_acc: 0.6000
Epoch: 32 loss: 0.6799 Train_acc: 0.6000
Epoch: 33 loss: 0.6772 Train_acc: 0.6333
Epoch: 34 loss: 0.6745 Train_acc: 0.6333
Epoch: 35 loss: 0.6716 Train_acc: 0.6667
Epoch: 36 loss: 0.6687 Train_acc: 0.7000
Epoch: 37 loss: 0.6657 Train_acc: 0.7333
Epoch: 38 loss: 0.6626 Train_acc: 0.7333
Epoch: 39 loss: 0.6596 Train_acc: 0.7667
Epoch: 40 loss: 0.6564 Train_acc: 0.9333
Test_acc: 1.0000
Epoch: 41 loss: 0.6533 Train_acc: 0.9667
Epoch: 42 loss: 0.6501 Train_acc: 0.9333
Epoch: 43 loss: 0.6469 Train_acc: 0.9333
Epoch: 44 loss: 0.6437 Train_acc: 0.9333
Epoch: 45 loss: 0.6406 Train_acc: 0.9333
Epoch: 46 loss: 0.6374 Train_acc: 0.9333
Epoch: 47 loss: 0.6343 Train_acc: 0.9667
Epoch: 48 loss: 0.6313 Train_acc: 0.9667
Epoch: 49 loss: 0.6282 Train_acc: 0.9667
Epoch: 50 loss: 0.6253 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 51 loss: 0.6223 Train_acc: 0.9667
Epoch: 52 loss: 0.6195 Train_acc: 0.9667
Epoch: 53 loss: 0.6167 Train_acc: 0.9667
Epoch: 54 loss: 0.6140 Train_acc: 0.9667
Epoch: 55 loss: 0.6113 Train_acc: 0.9667
Epoch: 56 loss: 0.6087 Train_acc: 0.9667
Epoch: 57 loss: 0.6062 Train_acc: 0.9667
Epoch: 58 loss: 0.6038 Train_acc: 0.9667
Epoch: 59 loss: 0.6014 Train_acc: 0.9667
Epoch: 60 loss: 0.5991 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 61 loss: 0.5969 Train_acc: 0.9667
Epoch: 62 loss: 0.5948 Train_acc: 0.9667
Epoch: 63 loss: 0.5927 Train_acc: 0.9667
Epoch: 64 loss: 0.5907 Train_acc: 0.9667
Epoch: 65 loss: 0.5887 Train_acc: 0.9667
Epoch: 66 loss: 0.5869 Train_acc: 0.9667
Epoch: 67 loss: 0.5851 Train_acc: 0.9667
Epoch: 68 loss: 0.5833 Train_acc: 0.9667
Epoch: 69 loss: 0.5817 Train_acc: 0.9667
Epoch: 70 loss: 0.5801 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 71 loss: 0.5785 Train_acc: 0.9667
Epoch: 72 loss: 0.5770 Train_acc: 0.9667
Epoch: 73 loss: 0.5756 Train_acc: 0.9667
Epoch: 74 loss: 0.5742 Train_acc: 0.9667
Epoch: 75 loss: 0.5728 Train_acc: 0.9667
Epoch: 76 loss: 0.5715 Train_acc: 0.9667
Epoch: 77 loss: 0.5703 Train_acc: 0.9667
Epoch: 78 loss: 0.5691 Train_acc: 0.9667
Epoch: 79 loss: 0.5679 Train_acc: 0.9667
Epoch: 80 loss: 0.5668 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 81 loss: 0.5657 Train_acc: 0.9667
Epoch: 82 loss: 0.5647 Train_acc: 0.9667
Epoch: 83 loss: 0.5637 Train_acc: 0.9667
Epoch: 84 loss: 0.5627 Train_acc: 0.9667
Epoch: 85 loss: 0.5618 Train_acc: 0.9667
Epoch: 86 loss: 0.5609 Train_acc: 0.9667
Epoch: 87 loss: 0.5601 Train_acc: 0.9667
Epoch: 88 loss: 0.5592 Train_acc: 0.9667
Epoch: 89 loss: 0.5584 Train_acc: 0.9667
Epoch: 90 loss: 0.5577 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 91 loss: 0.5569 Train_acc: 0.9667
Epoch: 92 loss: 0.5562 Train_acc: 0.9667
Epoch: 93 loss: 0.5555 Train_acc: 0.9667
Epoch: 94 loss: 0.5548 Train_acc: 0.9667
Epoch: 95 loss: 0.5542 Train_acc: 0.9667
Epoch: 96 loss: 0.5535 Train_acc: 0.9667
Epoch: 97 loss: 0.5529 Train_acc: 0.9667
Epoch: 98 loss: 0.5523 Train_acc: 0.9667
Epoch: 99 loss: 0.5518 Train_acc: 0.9667
Epoch: 100 loss: 0.5512 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 101 loss: 0.5507 Train_acc: 0.9667
Epoch: 102 loss: 0.5502 Train_acc: 0.9667
Epoch: 103 loss: 0.5497 Train_acc: 0.9667
Epoch: 104 loss: 0.5492 Train_acc: 0.9667
Epoch: 105 loss: 0.5487 Train_acc: 0.9667
Epoch: 106 loss: 0.5483 Train_acc: 0.9667
Epoch: 107 loss: 0.5478 Train_acc: 0.9667
Epoch: 108 loss: 0.5474 Train_acc: 0.9667
Epoch: 109 loss: 0.5470 Train_acc: 0.9667
Epoch: 110 loss: 0.5466 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 111 loss: 0.5462 Train_acc: 0.9667
Epoch: 112 loss: 0.5458 Train_acc: 0.9667
Epoch: 113 loss: 0.5455 Train_acc: 0.9667
Epoch: 114 loss: 0.5451 Train_acc: 0.9667
Epoch: 115 loss: 0.5448 Train_acc: 0.9667
Epoch: 116 loss: 0.5444 Train_acc: 0.9667
Epoch: 117 loss: 0.5441 Train_acc: 0.9667
Epoch: 118 loss: 0.5438 Train_acc: 0.9667
Epoch: 119 loss: 0.5435 Train_acc: 0.9667
Epoch: 120 loss: 0.5432 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 121 loss: 0.5429 Train_acc: 0.9667
Epoch: 122 loss: 0.5426 Train_acc: 0.9667
Epoch: 123 loss: 0.5423 Train_acc: 0.9667
Epoch: 124 loss: 0.5421 Train_acc: 0.9667
Epoch: 125 loss: 0.5418 Train_acc: 0.9667
Epoch: 126 loss: 0.5416 Train_acc: 0.9667
Epoch: 127 loss: 0.5413 Train_acc: 0.9667
Epoch: 128 loss: 0.5411 Train_acc: 0.9667
Epoch: 129 loss: 0.5408 Train_acc: 0.9667
Epoch: 130 loss: 0.5406 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 131 loss: 0.5404 Train_acc: 0.9667
Epoch: 132 loss: 0.5402 Train_acc: 0.9667
Epoch: 133 loss: 0.5400 Train_acc: 0.9667
Epoch: 134 loss: 0.5398 Train_acc: 0.9667
Epoch: 135 loss: 0.5395 Train_acc: 0.9667
Epoch: 136 loss: 0.5393 Train_acc: 0.9667
Epoch: 137 loss: 0.5392 Train_acc: 0.9667
Epoch: 138 loss: 0.5390 Train_acc: 0.9667
Epoch: 139 loss: 0.5388 Train_acc: 0.9667
Epoch: 140 loss: 0.5386 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 141 loss: 0.5384 Train_acc: 0.9667
Epoch: 142 loss: 0.5383 Train_acc: 0.9667
Epoch: 143 loss: 0.5381 Train_acc: 0.9667
Epoch: 144 loss: 0.5379 Train_acc: 0.9667
Epoch: 145 loss: 0.5378 Train_acc: 0.9667
Epoch: 146 loss: 0.5376 Train_acc: 0.9667
Epoch: 147 loss: 0.5374 Train_acc: 0.9667
Epoch: 148 loss: 0.5373 Train_acc: 0.9667
Epoch: 149 loss: 0.5371 Train_acc: 0.9667
Epoch: 150 loss: 0.5370 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 151 loss: 0.5369 Train_acc: 0.9667
Epoch: 152 loss: 0.5367 Train_acc: 0.9667
Epoch: 153 loss: 0.5366 Train_acc: 0.9667
Epoch: 154 loss: 0.5364 Train_acc: 0.9667
Epoch: 155 loss: 0.5363 Train_acc: 0.9667
Epoch: 156 loss: 0.5362 Train_acc: 0.9667
Epoch: 157 loss: 0.5361 Train_acc: 0.9667
Epoch: 158 loss: 0.5359 Train_acc: 0.9667
Epoch: 159 loss: 0.5358 Train_acc: 0.9667
Epoch: 160 loss: 0.5357 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 161 loss: 0.5356 Train_acc: 0.9667
Epoch: 162 loss: 0.5355 Train_acc: 0.9667
Epoch: 163 loss: 0.5354 Train_acc: 0.9667
Epoch: 164 loss: 0.5352 Train_acc: 0.9667
Epoch: 165 loss: 0.5351 Train_acc: 0.9667
Epoch: 166 loss: 0.5350 Train_acc: 0.9667
Epoch: 167 loss: 0.5349 Train_acc: 0.9667
Epoch: 168 loss: 0.5348 Train_acc: 0.9667
Epoch: 169 loss: 0.5347 Train_acc: 0.9667
Epoch: 170 loss: 0.5346 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 171 loss: 0.5345 Train_acc: 0.9667
Epoch: 172 loss: 0.5344 Train_acc: 0.9667
Epoch: 173 loss: 0.5343 Train_acc: 0.9667
Epoch: 174 loss: 0.5342 Train_acc: 0.9667
Epoch: 175 loss: 0.5341 Train_acc: 0.9667
Epoch: 176 loss: 0.5341 Train_acc: 0.9667
Epoch: 177 loss: 0.5340 Train_acc: 0.9667
Epoch: 178 loss: 0.5339 Train_acc: 0.9667
Epoch: 179 loss: 0.5338 Train_acc: 0.9667
Epoch: 180 loss: 0.5337 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 181 loss: 0.5336 Train_acc: 0.9667
Epoch: 182 loss: 0.5335 Train_acc: 0.9667
Epoch: 183 loss: 0.5335 Train_acc: 0.9667
Epoch: 184 loss: 0.5334 Train_acc: 0.9667
Epoch: 185 loss: 0.5333 Train_acc: 0.9667
Epoch: 186 loss: 0.5332 Train_acc: 0.9667
Epoch: 187 loss: 0.5332 Train_acc: 0.9667
Epoch: 188 loss: 0.5331 Train_acc: 0.9667
Epoch: 189 loss: 0.5330 Train_acc: 0.9667
Epoch: 190 loss: 0.5329 Train_acc: 0.9667
Test_acc: 1.0000
Epoch: 191 loss: 0.5329 Train_acc: 0.9667
Epoch: 192 loss: 0.5328 Train_acc: 0.9667
Epoch: 193 loss: 0.5327 Train_acc: 0.9667
Epoch: 194 loss: 0.5327 Train_acc: 0.9667
Epoch: 195 loss: 0.5326 Train_acc: 0.9667
Epoch: 196 loss: 0.5325 Train_acc: 0.9667
Epoch: 197 loss: 0.5325 Train_acc: 0.9667
Epoch: 198 loss: 0.5324 Train_acc: 0.9667
Epoch: 199 loss: 0.5323 Train_acc: 0.9667
Epoch: 200 loss: 0.5323 Train_acc: 0.9667
Test_acc: 1.0000
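The loss in the log above flattens near 0.53 instead of approaching zero. One possible explanation, offered as a reading of the code rather than something the notebook states, is that the model returns softmax probabilities while binary_cross_entropy_with_logits applies a sigmoid to its input, so the values are squashed twice. The NumPy sketch below re-implements the loss formula (it does not call PyTorch) and evaluates it on perfectly confident, perfectly correct probabilities to estimate the lowest loss reachable in that setup.

```python
import numpy as np


def bce_with_logits(z, y):
    """Mean binary cross-entropy of sigmoid(z) against targets y (stable form)."""
    z = np.asarray(z, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.mean(np.maximum(z, 0.0) - z * y + np.log1p(np.exp(-np.abs(z)))))


# Best case for a softmax output: probabilities equal to the one-hot targets.
perfect_probs = np.array([[1.0, 0.0], [0.0, 1.0]])
targets = perfect_probs.copy()

# The probabilities are fed in where logits are expected, as in the training cell.
loss_floor = bce_with_logits(perfect_probs, targets)
print(round(loss_floor, 4))
```

Under these assumptions the floor is about 0.503, which is consistent with the plateau in the log. Passing the raw linear outputs to the loss would remove this floor, although the logged numbers would then no longer apply.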