"""HyperGAT layer."""
from typing import Literal
import torch
from topomodelx.base.message_passing import MessagePassing
class HyperGATLayer(MessagePassing):
r"""Implementation of the HyperGAT layer proposed in [1]_.
Parameters
----------
in_channels : int
Dimension of the input features.
hidden_channels : int
Dimension of the output features.
    update_func : Literal["relu", "sigmoid"], default="relu"
        Update function applied to the aggregated messages, either a ReLU
        or a sigmoid.
initialization : Literal["xavier_uniform", "xavier_normal"], default="xavier_uniform"
Initialization method.
initialization_gain : float, default=1.414
Gain for the initialization.
**kwargs : optional
Additional arguments for the layer modules.

    References
    ----------
    .. [1] Kaize Ding, Jianling Wang, Jundong Li, Dingcheng Li and Huan Liu.
        Be more with less: hypergraph attention networks for inductive
        text classification.
        EMNLP 2020.
        https://aclanthology.org/2020.emnlp-main.399.pdf
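
    Examples
    --------
    A minimal sketch of a single forward pass on a toy hypergraph with three
    nodes and three hyperedges; the incidence matrix below is illustrative
    only and is not taken from [1]_.

    >>> import torch
    >>> layer = HyperGATLayer(in_channels=4, hidden_channels=8)
    >>> x_0 = torch.randn(3, 4)
    >>> incidence_1 = torch.tensor(
    ...     [[1.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 1.0]]
    ... ).to_sparse_coo()
    >>> x_0_out, x_1_out = layer(x_0, incidence_1)
    >>> x_0_out.shape, x_1_out.shape
    (torch.Size([3, 8]), torch.Size([3, 8]))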
"""
def __init__(
self,
in_channels,
hidden_channels,
update_func: str = "relu",
initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
initialization_gain: float = 1.414,
**kwargs,
) -> None:
super().__init__(
initialization=initialization, initialization_gain=initialization_gain
)
self.in_channels = in_channels
self.hidden_channels = hidden_channels
self.update_func = update_func
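
        # Learnable linear transforms: weight1 projects node features for the
        # node-to-hyperedge pass, weight2 projects hyperedge messages for the
        # hyperedge-to-node pass.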
self.weight1 = torch.nn.Parameter(
torch.Tensor(self.in_channels, self.hidden_channels)
)
self.weight2 = torch.nn.Parameter(
torch.Tensor(self.hidden_channels, self.hidden_channels)
)
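        # Attention vectors: att_weight1 scores node-level messages,
        # att_weight2 scores concatenated (source, target) feature pairs
        # for edge-level attention.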
self.att_weight1 = torch.nn.Parameter(torch.zeros(size=(hidden_channels, 1)))
self.att_weight2 = torch.nn.Parameter(
torch.zeros(size=(2 * hidden_channels, 1))
)
self.reset_parameters()
def reset_parameters(self):
r"""Reset parameters."""
if self.initialization == "xavier_uniform":
torch.nn.init.xavier_uniform_(self.weight1, gain=self.initialization_gain)
torch.nn.init.xavier_uniform_(self.weight2, gain=self.initialization_gain)
torch.nn.init.xavier_uniform_(
self.att_weight1.view(-1, 1), gain=self.initialization_gain
)
torch.nn.init.xavier_uniform_(
self.att_weight2.view(-1, 1), gain=self.initialization_gain
)
elif self.initialization == "xavier_normal":
torch.nn.init.xavier_normal_(self.weight1, gain=self.initialization_gain)
torch.nn.init.xavier_normal_(self.weight2, gain=self.initialization_gain)
torch.nn.init.xavier_normal_(
self.att_weight1.view(-1, 1), gain=self.initialization_gain
)
torch.nn.init.xavier_normal_(
self.att_weight2.view(-1, 1), gain=self.initialization_gain
)
else:
raise ValueError(
"Initialization method not recognized. "
"Should be either xavier_uniform or xavier_normal."
)
def attention(
self,
x_source,
x_target=None,
mechanism: Literal["node-level", "edge-level"] = "node-level",
):
r"""Compute attention weights for messages, as proposed in [1].
Parameters
----------
        x_source : torch.Tensor, shape = (n_source_cells, hidden_channels)
            Input features on source cells.
            Assumes that all source cells have the same rank r.
        x_target : torch.Tensor, shape = (n_target_cells, hidden_channels), optional
            Input features on target cells.
            Assumes that all target cells have the same rank r.
        mechanism : Literal["node-level", "edge-level"], default="node-level"
            Attention mechanism, as proposed in [1]_. If set to "node-level",
            computes node-level attention; if set to "edge-level", computes
            edge-level attention.

        Returns
-------
torch.Tensor, shape = (n_messages, 1)
Attention weights: one scalar per message between a source and a target cell.
"""
if mechanism == "node-level":
x_source_per_message = x_source[self.target_index_i]
return torch.nn.functional.softmax(
torch.matmul(
torch.nn.functional.leaky_relu(x_source_per_message),
self.att_weight1,
),
dim=1,
)
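        # Edge-level attention: score the LeakyReLU of the concatenated
        # (source, target) features per message against att_weight2.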
x_source_per_message = x_source[self.source_index_j]
x_target_per_message = (
x_source[self.target_index_i]
if x_target is None
else x_target[self.target_index_i]
)
x_source_target_per_message = torch.nn.functional.leaky_relu(
torch.cat([x_source_per_message, x_target_per_message], dim=1)
)
return torch.nn.functional.softmax(
torch.matmul(x_source_target_per_message, self.att_weight2), dim=1
)
def update(self, x_message_on_target):
r"""Update embeddings on each cell (step 4).
Parameters
----------
x_message_on_target : torch.Tensor, shape = (n_target_cells, hidden_channels)
            Output features on target cells.

        Returns
-------
torch.Tensor, shape = (n_target_cells, hidden_channels)
Updated output features on target cells.
"""
if self.update_func == "sigmoid":
return torch.sigmoid(x_message_on_target)
if self.update_func == "relu":
return torch.nn.functional.relu(x_message_on_target)
        raise ValueError(
            "Update function not recognized. Should be either sigmoid or relu."
        )
def forward(self, x_0, incidence_1):
r"""Forward pass.
.. math::
\begin{align*}
&🟥 \quad m_{y \rightarrow z}^{(0 \rightarrow 1) } = (B^T_1\odot att(h_{y \in \mathcal{B}(z)}^{t,(0)}))\_{zy} \cdot h^{t,(0)}y \cdot \Theta^{t,(0)}\\
&🟧 \quad m_z^{(1)} = \sigma(\sum_{y \in \mathcal{B}(z)} m_{y \rightarrow z}^{(0 \rightarrow 1)})\\
&🟥 \quad m_{z \rightarrow x}^{(1 \rightarrow 0)} = (B_1 \odot att(h_{z \in \mathcal{C}(x)}^{t,(1)}))\_{xz} \cdot m_{z}^{(1)} \cdot \Theta^{t,(1)}\\
&🟧 \quad m_{x}^{(0)} = \sum_{z \in \mathcal{C}(x)} m_{z \rightarrow x}^{(1\rightarrow0)}\\
&🟩 \quad m_x = m_{x}^{(0)}\\
&🟦 \quad h_x^{t+1, (0)} = \sigma(m_x)
\end{align*}
Parameters
----------
        x_0 : torch.Tensor, shape = (n_nodes, in_channels)
            Input features on the nodes.
        incidence_1 : torch.sparse, shape = (n_nodes, n_hyperedges)
            Incidence matrix mapping nodes to hyperedges.

        Returns
-------
        x_0 : torch.Tensor, shape = (n_nodes, hidden_channels)
            Output features on the nodes.
        x_1 : torch.Tensor, shape = (n_hyperedges, hidden_channels)
            Output features on the hyperedges.
"""
intra_aggregation = incidence_1.t() @ (x_0 @ self.weight1)
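        # Cache the (node, hyperedge) index of every nonzero incidence entry;
        # attention() reads these to gather per-message features.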
self.target_index_i, self.source_index_j = incidence_1.indices()
attention_values = self.attention(intra_aggregation).squeeze()
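        # Reweight each nonzero incidence entry by its node-level attention
        # score.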
incidence_with_attention = torch.sparse_coo_tensor(
indices=incidence_1.indices(),
values=incidence_1.values() * attention_values,
size=incidence_1.shape,
)
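        # Recompute the node-to-hyperedge aggregation with attention applied.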
intra_aggregation_with_attention = incidence_with_attention.t() @ (
x_0 @ self.weight1
)
messages_on_edges = self.update(intra_aggregation_with_attention)
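
        # Hyperedge-to-node pass: project the hyperedge messages with weight2
        # and sum them back onto the nodes.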
inter_aggregation = incidence_1 @ (messages_on_edges @ self.weight2)
attention_values = self.attention(
inter_aggregation, intra_aggregation
).squeeze()
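        # Reweight the incidence entries again, now with the second
        # attention pass.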
incidence_with_attention = torch.sparse_coo_tensor(
indices=incidence_1.indices(),
values=attention_values * incidence_1.values(),
size=incidence_1.shape,
)
inter_aggregation_with_attention = incidence_with_attention @ (
messages_on_edges @ self.weight2
)
return self.update(inter_aggregation_with_attention), messages_on_edges