Aggregation

Aggregation module.

class topomodelx.base.aggregation.Aggregation(aggr_func: Literal['mean', 'sum'] = 'sum', update_func: Literal['relu', 'sigmoid', 'tanh'] | None = 'sigmoid')

Message passing layer.

Parameters:
aggr_func : {"mean", "sum"}, default="sum"

Aggregation method used to merge messages across neighborhoods (inter-neighborhood aggregation).

update_func : {"relu", "sigmoid", "tanh", None}, default="sigmoid"

Update function to apply to the merged message.

forward(x)

Forward pass.

Parameters:
x : list

A list of messages to merge. The list has length n_messages_to_merge, and each message is a tensor of shape [n_skeleton_in, channels].

Returns:
torch.Tensor

Aggregated messages.
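The merge step described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: the helper name `merge_messages` is hypothetical, and it assumes the messages are stacked along a new leading dimension and then reduced over it, so the result keeps the shape [n_skeleton_in, channels].

```python
import torch


def merge_messages(messages, aggr_func="sum"):
    """Hypothetical helper: merge a list of equally shaped messages."""
    # Stack into one tensor of shape [n_messages_to_merge, n_skeleton_in, channels].
    stacked = torch.stack(messages)
    if aggr_func == "sum":
        return stacked.sum(dim=0)
    if aggr_func == "mean":
        return stacked.mean(dim=0)
    raise ValueError(f"Unsupported aggr_func: {aggr_func!r}")


# Two messages, each with n_skeleton_in = 4 and channels = 3.
messages = [torch.ones(4, 3), 3 * torch.ones(4, 3)]
summed = merge_messages(messages, aggr_func="sum")     # every entry is 1 + 3
averaged = merge_messages(messages, aggr_func="mean")  # every entry is (1 + 3) / 2
```

Reducing over the stacked dimension requires every message in the list to have the same shape; `torch.stack` raises an error otherwise.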

update(inputs)

Update (Step 4).

Parameters:
inputs : torch.Tensor

Features for the update step.

Returns:
torch.Tensor

Updated features with the same shape as inputs.
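As a rough sketch of the update step, each documented option can be read as an element-wise activation, which is why the output shape matches the input shape. The helper `apply_update` and the `UPDATE_FUNCS` table are hypothetical names used only for illustration, and returning the input unchanged when `update_func` is None is an assumption of this sketch.

```python
import torch

# Hypothetical lookup table mapping the documented option names to PyTorch functions.
UPDATE_FUNCS = {"relu": torch.relu, "sigmoid": torch.sigmoid, "tanh": torch.tanh}


def apply_update(inputs, update_func="sigmoid"):
    """Hypothetical helper: apply an element-wise update to merged features."""
    if update_func is None:
        # Assumption of this sketch: no update means the features pass through.
        return inputs
    return UPDATE_FUNCS[update_func](inputs)


features = torch.tensor([[-1.0, 0.0, 2.0]])
relu_out = apply_update(features, "relu")        # negative entries are clamped to 0
sigmoid_out = apply_update(features, "sigmoid")  # entries are squashed into (0, 1)
```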