Loss
This module implements custom Python classes to compute losses in TopoBenchmarkX.
Abstract base class for losses.
class topobenchmarkx.loss.base.AbstractLoss
Abstract base class that every loss class implements.
abstract forward(model_out: dict, batch: Data)
Forward pass.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
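The interface above can be illustrated with a minimal, dependency-free sketch. Everything in it is an assumption made for illustration: `AbstractLossSketch` and `MeanAbsoluteErrorSketch` are hypothetical stand-ins rather than classes from the package, plain Python lists replace tensors, a `SimpleNamespace` replaces `torch_geometric.data.Data`, and the `"logits"`, `"loss"` and `batch.y` names are guesses about where predictions, results and targets live.

```python
from abc import ABC, abstractmethod
from types import SimpleNamespace


class AbstractLossSketch(ABC):
    """Hypothetical stand-in that mirrors the documented forward() signature."""

    @abstractmethod
    def forward(self, model_out: dict, batch) -> dict:
        """Compute a loss from the model output and the batch."""


class MeanAbsoluteErrorSketch(AbstractLossSketch):
    """Concrete loss: mean absolute error between predictions and targets."""

    def forward(self, model_out: dict, batch) -> dict:
        preds = model_out["logits"]  # assumed key for the predictions
        targets = batch.y            # assumed attribute holding the targets
        errors = [abs(p - t) for p, t in zip(preds, targets)]
        # Store the loss next to the existing entries and return the same dict.
        model_out["loss"] = sum(errors) / len(errors)
        return model_out


batch = SimpleNamespace(y=[1.0, 2.0, 4.0])
out = MeanAbsoluteErrorSketch().forward({"logits": [1.5, 2.0, 3.0]}, batch)
print(out["loss"])  # 0.5
```

Because `forward` is abstract, the base class itself cannot be instantiated; only subclasses that implement it can.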
Loss module for the topobenchmarkx package.
class topobenchmarkx.loss.loss.TBXLoss(task, loss_type=None)
Defines the default model loss for the given task.
- Parameters:
- task : str
Task type, either “classification” or “regression”.
- loss_type : str, optional
Loss type, one of “cross_entropy”, “mse”, or “mae” (default: None).
forward(model_out: dict, batch: Data)
Forward pass of the loss function.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched domain data.
- Returns:
- dict
The model output dictionary with the computed loss added.
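One way to picture how a task and a loss type could select a criterion is the dependency-free sketch below. It is not the package's implementation: `TBXLossSketch` and the `CRITERIA` table are hypothetical, the accepted task and loss-type pairings are assumed, plain lists stand in for tensors, and the `"logits"`, `"loss"` and `batch.y` names are illustrative. How `loss_type=None` is resolved is not stated in this reference, so the sketch simply rejects any pair it does not recognise.

```python
import math
from types import SimpleNamespace


def cross_entropy(logits, labels):
    """Mean negative log-likelihood of integer labels under softmax(logits)."""
    total = 0.0
    for row, label in zip(logits, labels):
        peak = max(row)  # subtract the max for numerical stability
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in row))
        total += log_norm - row[label]
    return total / len(labels)


def mse(preds, targets):
    """Mean squared error."""
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


def mae(preds, targets):
    """Mean absolute error."""
    return sum(abs(p - t) for p, t in zip(preds, targets)) / len(targets)


# Assumed pairing of task and loss_type (not confirmed by the reference).
CRITERIA = {
    ("classification", "cross_entropy"): cross_entropy,
    ("regression", "mse"): mse,
    ("regression", "mae"): mae,
}


class TBXLossSketch:
    """Hypothetical stand-in that picks a criterion from (task, loss_type)."""

    def __init__(self, task, loss_type=None):
        key = (task, loss_type)
        if key not in CRITERIA:
            raise ValueError(f"no loss for task={task!r}, loss_type={loss_type!r}")
        self.criterion = CRITERIA[key]

    def forward(self, model_out, batch):
        # Add the loss to the model output and return the same dictionary.
        model_out["loss"] = self.criterion(model_out["logits"], batch.y)
        return model_out


batch = SimpleNamespace(y=[0.0, 1.0])
result = TBXLossSketch("regression", "mse").forward({"logits": [1.0, 3.0]}, batch)
print(result["loss"])  # 2.5
```

Keeping the pairings in a single lookup table makes unsupported combinations fail at construction time rather than during the first forward pass.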