Loss

This module implements custom Python classes to compute losses in TopoBenchmarkX.

Abstract base class for loss functions.

class topobenchmarkx.loss.base.AbstractLoss

Abstract base class for loss functions.

abstract forward(model_out: dict, batch: Data)

Forward pass.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.
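As a rough illustration of the interface above, the following sketch defines a stand-in abstract class and one concrete subclass. It does not import topobenchmarkx: `AbstractLossSketch`, `MeanAbsoluteErrorSketch`, and the `"logits"`, `"labels"`, and `"loss"` keys are hypothetical names chosen for this example, and plain Python lists stand in for tensors.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class AbstractLossSketch(ABC):
    """Hypothetical stand-in for an abstract loss with a forward(model_out, batch) method."""

    @abstractmethod
    def forward(self, model_out: dict, batch) -> dict:
        """Compute a loss from the model output and the batch."""

    def __call__(self, model_out: dict, batch) -> dict:
        # Delegate calls to forward, loosely mimicking how torch modules behave.
        return self.forward(model_out, batch)


class MeanAbsoluteErrorSketch(AbstractLossSketch):
    """Concrete subclass: mean absolute error over plain Python lists."""

    def forward(self, model_out: dict, batch) -> dict:
        preds = model_out["logits"]    # assumed key, not confirmed by the docs
        targets = model_out["labels"]  # assumed key, not confirmed by the docs
        errors = [abs(p - t) for p, t in zip(preds, targets)]
        model_out["loss"] = sum(errors) / len(errors)
        return model_out


# SimpleNamespace is only a placeholder for a torch_geometric.data.Data batch.
batch = SimpleNamespace(num_graphs=1)
out = MeanAbsoluteErrorSketch()({"logits": [1.0, 2.0], "labels": [0.0, 4.0]}, batch)
```

Because `forward` is abstract, the base class cannot be instantiated directly; only subclasses that implement `forward` can.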

Loss module for the topobenchmarkx package.

class topobenchmarkx.loss.loss.TBXLoss(task, loss_type=None)

Defines the default model loss for the given task.

Parameters:
task : str

Task type, either “classification” or “regression”.

loss_type : str, optional

Loss type, either “cross_entropy”, “mse”, or “mae” (default: None).
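One plausible reading of these two parameters is that `task` restricts which `loss_type` values make sense. The sketch below encodes that reading in a hypothetical helper, `resolve_loss_name`, which is not part of topobenchmarkx. The pairing of "classification" with "cross_entropy" and "regression" with "mse" or "mae", and the fallback used when `loss_type` is `None`, are assumptions made for this example rather than documented behavior.

```python
from typing import Optional

# Assumed pairing of task types and loss names (not confirmed by the docs).
ALLOWED_LOSSES = {
    "classification": ("cross_entropy",),
    "regression": ("mse", "mae"),
}


def resolve_loss_name(task: str, loss_type: Optional[str] = None) -> str:
    """Hypothetical helper: validate (task, loss_type) and return a loss name."""
    if task not in ALLOWED_LOSSES:
        raise ValueError(f"Unknown task: {task!r}")
    allowed = ALLOWED_LOSSES[task]
    if loss_type is None:
        # Assumption: fall back to the first loss listed for the task.
        return allowed[0]
    if loss_type not in allowed:
        raise ValueError(f"Loss {loss_type!r} is not valid for task {task!r}")
    return loss_type


classification_loss = resolve_loss_name("classification")
regression_loss = resolve_loss_name("regression", "mae")
```

Validating the combination up front means a mismatch such as a classification task with "mse" fails at construction time rather than during training.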

forward(model_out: dict, batch: Data)

Forward pass of the loss function.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched domain data.

Returns:
dict

Dictionary containing the model output with the loss.
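The return contract (the model output dictionary, with the loss added) can be sketched without any dependencies. The example below is a minimal pure-Python cross-entropy for a classification task, not the library's implementation. The function names and the `"logits"`, `"labels"`, and `"loss"` keys are assumptions made for illustration, and nested lists stand in for tensors.

```python
import math


def cross_entropy(logits, labels):
    """Mean cross-entropy over rows of raw scores, using a stable log-sum-exp."""
    total = 0.0
    for row, label in zip(logits, labels):
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        total += log_sum_exp - row[label]
    return total / len(labels)


def forward_sketch(model_out: dict, batch) -> dict:
    """Hypothetical forward: add a "loss" entry and return the same dictionary."""
    # This sketch ignores `batch` and assumes the labels are stored in model_out.
    model_out["loss"] = cross_entropy(model_out["logits"], model_out["labels"])
    return model_out


model_out = {"logits": [[0.0, 0.0], [2.0, 0.0]], "labels": [0, 0]}
result = forward_sketch(model_out, batch=None)
```

Returning the enriched dictionary rather than a bare scalar keeps the predictions and targets available to later steps, such as metric computation, alongside the loss.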