Model

This module implements custom Python classes that represent models built on PyTorch Lightning within TopoBenchmarkX.

This module defines the TBXModel class.

class topobenchmarkx.model.model.TBXModel(backbone: Module, backbone_wrapper: Module, readout: Module, loss: Module, feature_encoder: Module | None = None, evaluator: Any = None, optimizer: Any = None, **kwargs)

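A LightningModule that assembles a backbone, backbone wrapper, readout, and loss (plus an optional feature encoder, evaluator, and optimizer) into a trainable network.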

Parameters:
backbone : torch.nn.Module

The backbone model to train.

backbone_wrapper : torch.nn.Module

The backbone wrapper class.

readout : torch.nn.Module

The readout class.

loss : torch.nn.Module

The loss class.

feature_encoder : torch.nn.Module, optional

The feature encoder (default: None).

evaluator : Any, optional

The evaluator class (default: None).

optimizer : Any, optional

The optimizer class (default: None).

**kwargs : Any

Additional keyword arguments.

configure_optimizers() → dict[str, Any]

Configure optimizers and learning-rate schedulers.

Choose which optimizers and learning-rate schedulers to use during optimization. Usually one optimizer is enough, but setups such as GANs may need several.

Examples

https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

Returns:
dict

A dict containing the configured optimizers and learning-rate schedulers to be used for training.
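As a rough illustration of the kind of dictionary Lightning accepts from configure_optimizers, the sketch below builds one by hand. Adam and StepLR are assumptions chosen only for the example, not TBXModel's defaults (those come from its optimizer argument), and the helper name build_optimizer_config is hypothetical.

```python
import torch
from torch import nn


def build_optimizer_config(model: nn.Module, lr: float = 1e-3) -> dict:
    """Hypothetical helper returning a Lightning-style optimizer config."""
    # Adam and StepLR are illustrative choices, not TBXModel's defaults.
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            # Lightning reads these keys to decide when to step the scheduler.
            "interval": "epoch",
            "frequency": 1,
        },
    }


config = build_optimizer_config(nn.Linear(4, 2))
print(type(config["optimizer"]).__name__)  # Adam
```

The nested "lr_scheduler" dictionary is what lets Lightning step the scheduler per epoch or per step without extra code in the training loop.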

forward(batch: Data) → dict

Perform a forward pass through the backbone model (self.backbone).

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the model output.

log_metrics(mode=None)

Log metrics.

Parameters:
mode : str, optional

The mode of the model, either “train”, “val”, or “test” (default: None).
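A pure-Python sketch of mode-prefixed metric logging follows. It assumes an evaluator object exposing compute() and reset(), a log_dict-style callback, and a "mode/metric" naming scheme; ToyEvaluator and log_metrics_sketch are hypothetical names, not part of TopoBenchmarkX.

```python
from typing import Callable, Dict, Optional


class ToyEvaluator:
    """Hypothetical evaluator that tracks accuracy over the batches it has seen."""

    def __init__(self) -> None:
        self.reset()

    def update(self, preds, labels) -> None:
        self.correct += sum(int(p == y) for p, y in zip(preds, labels))
        self.total += len(labels)

    def compute(self) -> Dict[str, float]:
        return {"accuracy": self.correct / self.total if self.total else 0.0}

    def reset(self) -> None:
        self.correct = 0
        self.total = 0


def log_metrics_sketch(
    evaluator: ToyEvaluator,
    log_fn: Callable[[Dict[str, float]], None],
    mode: Optional[str] = None,
) -> Dict[str, float]:
    """Prefix each metric with the mode (e.g. "train/accuracy") and log it."""
    metrics = evaluator.compute()
    prefix = f"{mode}/" if mode else ""
    logged = {f"{prefix}{name}": value for name, value in metrics.items()}
    log_fn(logged)
    return logged


evaluator = ToyEvaluator()
evaluator.update([1, 0, 1, 1], [1, 0, 0, 1])
print(log_metrics_sketch(evaluator, log_fn=lambda d: None, mode="val"))
# {'val/accuracy': 0.75}
```

Prefixing by mode keeps train, validation, and test values distinct in the logger even when they are computed by the same evaluator.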

model_step(batch: Data) → dict

Perform a single model step on a batch of data.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the model output and the loss.
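The sketch below shows one plausible way a single model step could chain the components named in the constructor (feature encoder, backbone, readout, loss) and return a dictionary holding both the output and the loss. The plain-dict batch standing in for torch_geometric.data.Data, the keys "x", "y", "logits", "labels", and "loss", and the function name are all assumptions for illustration.

```python
from typing import Optional

import torch
from torch import nn


def model_step_sketch(batch: dict, encoder: Optional[nn.Module], backbone: nn.Module,
                      readout: nn.Module, loss_fn: nn.Module) -> dict:
    """Hypothetical model step: (encode) -> backbone -> readout -> loss."""
    # The feature encoder is optional, mirroring feature_encoder=None.
    x = encoder(batch["x"]) if encoder is not None else batch["x"]
    hidden = backbone(x)        # backbone (possibly wrapped) forward pass
    logits = readout(hidden)    # map hidden features to predictions
    model_out = {"logits": logits, "labels": batch["y"]}
    model_out["loss"] = loss_fn(model_out["logits"], model_out["labels"])
    return model_out


torch.manual_seed(0)
batch = {"x": torch.randn(5, 8), "y": torch.tensor([0, 1, 2, 1, 0])}
out = model_step_sketch(
    batch,
    encoder=nn.Linear(8, 16),
    backbone=nn.Sequential(nn.ReLU(), nn.Linear(16, 16)),
    readout=nn.Linear(16, 3),
    loss_fn=nn.CrossEntropyLoss(),
)
print(out["logits"].shape, out["loss"].item())
```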

on_test_epoch_end() → None

Lightning hook that is called when a test epoch ends.

This hook is used to log the test metrics.

on_test_epoch_start() → None

Lightning hook that is called when a test epoch begins.

This hook is used to reset the test metrics.

on_train_epoch_end() → None

Lightning hook that is called when a train epoch ends.

This hook is used to log the train metrics.

on_train_epoch_start() → None

Lightning hook that is called when a train epoch begins.

This hook is used to reset the train metrics.

on_val_epoch_start() → None

Reset the validation metrics when a validation epoch begins.

Note that Lightning does not call a hook with this name automatically; the built-in hook for the start of a validation epoch is on_validation_epoch_start.

on_validation_epoch_end() → None

Lightning hook that is called when a validation epoch ends.

This hook is used to log the validation metrics.

on_validation_epoch_start() → None

Hook called when a validation epoch begins.

According to the PyTorch Lightning documentation, this hook is called at the beginning of the validation epoch.

https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks

Note that validation runs inside the training epoch. The train metrics must therefore be logged here, before the evaluator is reset for the validation loop.
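A pure-Python sketch of the reset/log pattern behind these epoch hooks, including the ordering constraint just described: train metrics are logged before the shared evaluator state is reset for validation. The class HookLifecycleSketch, its list-based buffer, and the running mean standing in for real metrics are assumptions; only the hook names come from this page.

```python
from typing import Dict, List


class HookLifecycleSketch:
    """Hypothetical stand-in for the metric handling in the epoch hooks."""

    def __init__(self) -> None:
        self.buffer: List[float] = []       # stands in for the evaluator state
        self.logged: Dict[str, float] = {}

    def _log(self, mode: str) -> None:
        if self.buffer:
            self.logged[f"{mode}/mean"] = sum(self.buffer) / len(self.buffer)

    def on_train_epoch_start(self) -> None:
        self.buffer.clear()                 # reset train metrics

    def on_validation_epoch_start(self) -> None:
        self._log("train")                  # log train metrics first,
        self.buffer.clear()                 # then reset for validation

    def on_validation_epoch_end(self) -> None:
        self._log("val")

    def step(self, value: float) -> None:
        self.buffer.append(value)


hooks = HookLifecycleSketch()
hooks.on_train_epoch_start()
for v in (1.0, 3.0):
    hooks.step(v)                           # training batches
hooks.on_validation_epoch_start()
hooks.step(10.0)                            # validation batch
hooks.on_validation_epoch_end()
print(hooks.logged)  # {'train/mean': 2.0, 'val/mean': 10.0}
```

If the reset ran before the log call in on_validation_epoch_start, the training values would be discarded and no "train/mean" entry would appear.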

process_outputs(model_out: dict, batch: Data) → dict

Handle model outputs.

Parameters:
model_out : dict

Dictionary containing the model output.

batch : torch_geometric.data.Data

Batch object containing the batched data.

Returns:
dict

Dictionary containing the updated model output.
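One plausible form of output handling, sketched under assumptions: for a node-level task, logits and labels are restricted to the nodes of the current split with a boolean mask. The mask argument, the "logits"/"labels" keys, and the function name are hypothetical, and the actual TBXModel.process_outputs may behave differently.

```python
import torch


def process_outputs_sketch(model_out: dict, mask: torch.Tensor) -> dict:
    """Hypothetical post-processing: keep only the entries selected by mask."""
    # Boolean indexing keeps the rows where mask is True (e.g. training nodes).
    model_out["logits"] = model_out["logits"][mask]
    model_out["labels"] = model_out["labels"][mask]
    return model_out


out = {"logits": torch.randn(4, 3), "labels": torch.tensor([2, 0, 1, 1])}
train_mask = torch.tensor([True, False, True, False])
out = process_outputs_sketch(out, train_mask)
print(out["logits"].shape, out["labels"].tolist())  # torch.Size([2, 3]) [2, 1]
```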

setup(stage: str) → None

Hook to call torch.compile.

Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.

This hook is useful when models need to be built dynamically or adjusted. It is called on every process when using DDP.

Parameters:
stage : str

Either “fit”, “validate”, “test”, or “predict”.
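A small sketch of a common pattern behind a setup hook that calls torch.compile: compile only when a flag is enabled and the stage is "fit". The compile_model flag and the helper name maybe_compile are assumptions; torch.compile itself requires PyTorch 2.0 or later.

```python
import torch
from torch import nn


def maybe_compile(module: nn.Module, stage: str, compile_model: bool) -> nn.Module:
    """Hypothetical setup logic: wrap the module with torch.compile during fit."""
    if compile_model and stage == "fit":
        # torch.compile returns an optimized wrapper; the actual compilation
        # is deferred until the first forward call.
        return torch.compile(module)
    return module


net = nn.Linear(4, 2)
print(maybe_compile(net, stage="test", compile_model=True) is net)  # True
```

Restricting compilation to "fit" avoids paying compile time for short validate or test runs, at the cost of running those stages uncompiled.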

test_step(batch: Data, batch_idx: int) → None

Perform a single test step on a batch of data.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

batch_idx : int

The index of the current batch.

training_step(batch: Data, batch_idx: int) → Tensor

Perform a single training step on a batch of data.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

batch_idx : int

The index of the current batch.

Returns:
torch.Tensor

A tensor of losses between model predictions and targets.
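Under automatic optimization, Lightning calls backward on the loss returned by training_step, then steps and zeroes the optimizer. The sketch below performs those calls by hand for one batch so that it runs without Lightning; the linear model, the SGD optimizer, and the function name manual_training_step are illustrative assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


def manual_training_step(model: nn.Module, optimizer: torch.optim.Optimizer,
                         x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for training_step plus Lightning's automatic optimization."""
    loss = F.cross_entropy(model(x), y)  # what training_step would return
    # Lightning performs the next three calls automatically.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.detach()


torch.manual_seed(0)
model = nn.Linear(8, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(16, 8), torch.randint(0, 3, (16,))
first = manual_training_step(model, optimizer, x, y)
last = first
for _ in range(20):
    last = manual_training_step(model, optimizer, x, y)
print(first.item() > last.item())  # True: the loss decreases on this fixed batch
```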

validation_step(batch: Data, batch_idx: int) → None

Perform a single validation step on a batch of data.

Parameters:
batch : torch_geometric.data.Data

Batch object containing the batched data.

batch_idx : int

The index of the current batch.
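validation_step and test_step return None. Their role is to feed predictions into the metrics that the epoch-end hooks later log. Lightning already runs these loops without gradient tracking, so this hand-written sketch wraps the step in torch.no_grad() to mimic that. The accumulator list standing in for an evaluator and the function name eval_step_sketch are hypothetical.

```python
import torch
from torch import nn


def eval_step_sketch(model: nn.Module, x: torch.Tensor, y: torch.Tensor,
                     correct_counts: list) -> None:
    """Hypothetical validation/test step: update metrics, return nothing."""
    model.eval()
    with torch.no_grad():  # Lightning disables gradients for these loops itself
        preds = model(x).argmax(dim=-1)
    # Per-batch counts would later be aggregated by the epoch-end hooks.
    correct_counts.append(int((preds == y).sum()))


counts: list = []
model = nn.Linear(4, 2)
for _ in range(3):
    x, y = torch.randn(5, 4), torch.randint(0, 2, (5,))
    eval_step_sketch(model, x, y, counts)
print(len(counts), all(0 <= c <= 5 for c in counts))  # 3 True
```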