Model
This module implements custom Python classes that represent models built on PyTorch Lightning within TopoBenchmarkX. It defines the TBXModel class.
- class topobenchmarkx.model.model.TBXModel(backbone: Module, backbone_wrapper: Module, readout: Module, loss: Module, feature_encoder: Module | None = None, evaluator: Any = None, optimizer: Any = None, **kwargs)
A LightningModule to define a network.
- Parameters:
- backbone : torch.nn.Module
The backbone model to train.
- backbone_wrapper : torch.nn.Module
The backbone wrapper class.
- readout : torch.nn.Module
The readout class.
- loss : torch.nn.Module
The loss class.
- feature_encoder : torch.nn.Module, optional
The feature encoder (default: None).
- evaluator : Any, optional
The evaluator class (default: None).
- optimizer : Any, optional
The optimizer class (default: None).
- **kwargs : Any
Additional keyword arguments.
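The constructor receives each stage of the network as a separate module. As a rough illustration of that wiring, the sketch below uses plain Python callables in place of torch.nn.Module instances so it runs without PyTorch. The class name ToyTBXModel, the helper add_one_wrapper, and the assumption that backbone_wrapper is called once with the backbone to produce the callable used in the forward pass are all hypothetical, not taken from TopoBenchmarkX.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional

Fn = Callable[[Any], Any]


@dataclass
class ToyTBXModel:
    """Hypothetical stand-in for TBXModel's constructor wiring."""

    backbone: Fn
    backbone_wrapper: Callable[[Fn], Fn]
    readout: Fn
    loss: Callable[[Any, Any], Any]
    feature_encoder: Optional[Fn] = None  # optional, as in the real signature
    evaluator: Any = None
    optimizer: Any = None

    def __post_init__(self) -> None:
        # Assumption: the wrapper receives the raw backbone and returns the
        # callable that is actually used during the forward pass.
        self.wrapped_backbone = self.backbone_wrapper(self.backbone)


def add_one_wrapper(fn: Fn) -> Fn:
    # Hypothetical wrapper that post-processes the backbone output.
    return lambda x: fn(x) + 1


model = ToyTBXModel(
    backbone=lambda x: 2 * x,
    backbone_wrapper=add_one_wrapper,
    readout=lambda h: h,
    loss=lambda pred, target: (pred - target) ** 2,
)
print(model.wrapped_backbone(3))  # 7
```

Keeping the wrapper separate from the backbone lets the same backbone be reused with different wrappers without changing the backbone itself.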
- configure_optimizers() → dict[str, Any]
Configure optimizers and learning-rate schedulers.
Choose which optimizers and learning-rate schedulers to use in the optimization. Usually a single optimizer is enough, but setups such as GANs may need several.
Examples
See https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers
- Returns:
- dict
A dict containing the configured optimizers and learning-rate schedulers to be used for training.
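One dict shape that PyTorch Lightning accepts from configure_optimizers is an "optimizer" key plus an optional "lr_scheduler" sub-dict. The minimal sketch below assembles that shape; the helper name build_optimizer_config and the monitored metric name "val/loss" are hypothetical, and strings stand in for real optimizer and scheduler objects.

```python
from typing import Any, Optional


def build_optimizer_config(
    optimizer: Any,
    scheduler: Optional[Any] = None,
    monitor: str = "val/loss",  # hypothetical metric name
    interval: str = "epoch",
) -> dict[str, Any]:
    """Assemble a dict in one of the shapes Lightning accepts from configure_optimizers."""
    config: dict[str, Any] = {"optimizer": optimizer}
    if scheduler is not None:
        config["lr_scheduler"] = {
            "scheduler": scheduler,
            # "monitor" matters for schedulers such as ReduceLROnPlateau
            # that step on a logged metric.
            "monitor": monitor,
            "interval": interval,  # step the scheduler every "epoch" or "step"
            "frequency": 1,
        }
    return config


print(build_optimizer_config("opt"))  # {'optimizer': 'opt'}
```

Returning the scheduler settings in a nested dict keeps the stepping interval and monitored metric next to the scheduler they configure.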
- forward(batch: Data) → dict
Perform a forward pass through the backbone model, self.backbone.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the model output.
- log_metrics(mode=None)
Log metrics.
- Parameters:
- mode : str, optional
The mode of the model, either “train”, “val”, or “test” (default: None).
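A possible shape for log_metrics, assuming (hypothetically) that the evaluator produces a dict of metric values and that each value is logged under a key prefixed with the mode. The prefix scheme and the function signature are illustrative only, not the documented behavior of TBXModel.

```python
from typing import Callable, Optional


def log_metrics(
    computed: dict[str, float],
    log: Callable[[str, float], None],
    mode: Optional[str] = None,
) -> None:
    """Hypothetical sketch: log each metric under a '<mode>/<name>' key."""
    for name, value in computed.items():
        # Without a mode, fall back to the bare metric name.
        key = f"{mode}/{name}" if mode is not None else name
        log(key, value)


logged: dict[str, float] = {}
log_metrics({"accuracy": 0.9, "loss": 0.3}, logged.__setitem__, mode="val")
print(logged)  # {'val/accuracy': 0.9, 'val/loss': 0.3}
```

Prefixing by mode keeps train, validation, and test curves apart when all three stages report metrics with the same names.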
- model_step(batch: Data) → dict
Perform a single model step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the model output and the loss.
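A model step chains the components passed to the constructor. The sketch below works under stated assumptions: the batch is a dict with hypothetical "x" and "y" keys, the optional feature encoder runs before the backbone, and the loss is stored in the output dict under a hypothetical "loss" key. Lists of floats replace tensors so the example runs without PyTorch; the exact order of stages inside TBXModel may differ.

```python
from typing import Any, Callable, Optional

Fn = Callable[..., Any]


def model_step(
    batch: dict[str, Any],
    backbone: Fn,
    readout: Fn,
    loss: Fn,
    feature_encoder: Optional[Fn] = None,
) -> dict[str, Any]:
    """Hypothetical sketch of one model step: encode, forward, read out, score."""
    if feature_encoder is not None:
        batch = feature_encoder(batch)
    model_out = {"x": backbone(batch["x"])}  # forward pass through the backbone
    model_out["logits"] = readout(model_out["x"])
    # Store the loss alongside the outputs so callers can log or return it.
    model_out["loss"] = loss(model_out["logits"], batch["y"])
    return model_out


out = model_step(
    {"x": [1.0, 2.0], "y": [3.0, 5.0]},
    backbone=lambda xs: [2 * v for v in xs],
    readout=lambda hs: [h + 1 for h in hs],
    loss=lambda pred, y: sum((p - t) ** 2 for p, t in zip(pred, y)) / len(y),
)
print(out["loss"])  # 0.0
```

Returning one dict that carries both predictions and loss lets the training, validation, and test steps share the same code path.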
- on_test_epoch_end() → None
Lightning hook that is called when a test epoch ends.
This hook is used to log the test metrics.
- on_test_epoch_start() → None
Lightning hook that is called when a test epoch begins.
This hook is used to reset the test metrics.
- on_train_epoch_end() → None
Lightning hook that is called when a train epoch ends.
This hook is used to log the train metrics.
- on_train_epoch_start() → None
Lightning hook that is called when a train epoch begins.
This hook is used to reset the train metrics.
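- on_val_epoch_start() → None
Hook intended to run when a validation epoch begins.
This hook is used to reset the validation metrics. Note that PyTorch Lightning's built-in hook for this event is named on_validation_epoch_start, so Lightning does not call on_val_epoch_start automatically.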
- on_validation_epoch_end() → None
Lightning hook that is called when a validation epoch ends.
This hook is used to log the validation metrics.
- on_validation_epoch_start() → None
Lightning hook that is called at the beginning of the validation epoch, according to the PyTorch Lightning documentation:
https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#hooks
Because the validation loop runs inside the training epoch, the train metrics must be logged here, before the evaluator is reset to start the validation loop.
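The ordering constraint described for on_validation_epoch_start can be simulated in a few lines. The sketch assumes a single evaluator shared by training and validation (ToyEvaluator and run_epoch are hypothetical) and follows PyTorch Lightning's order, in which the validation loop runs before the training epoch ends. If the train metrics were not read before the reset, they would be lost.

```python
class ToyEvaluator:
    """Hypothetical evaluator shared by all stages; accumulates a running mean."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.total, self.count = 0.0, 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def compute(self) -> float:
        return self.total / self.count


def run_epoch(train_values: list[float], val_values: list[float]) -> dict[str, float]:
    """Simulate the hook order for one training epoch that includes validation."""
    evaluator, logged = ToyEvaluator(), {}
    evaluator.reset()                      # on_train_epoch_start
    for v in train_values:                 # training_step
        evaluator.update(v)
    # on_validation_epoch_start: validation runs inside the train epoch, so
    # the train metrics are logged now, before the shared evaluator is reset.
    logged["train"] = evaluator.compute()
    evaluator.reset()
    for v in val_values:                   # validation_step
        evaluator.update(v)
    logged["val"] = evaluator.compute()    # on_validation_epoch_end
    return logged


print(run_epoch([1.0, 3.0], [10.0]))  # {'train': 2.0, 'val': 10.0}
```

Moving the train-metric logging to on_train_epoch_end instead would report validation values as train metrics, because the shared evaluator would already hold the validation results by then.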
- process_outputs(model_out: dict, batch: Data) → dict
Handle model outputs.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
- setup(stage: str) → None
Hook to call torch.compile.
Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
This is a good hook when you need to build models dynamically or adjust something about them. This hook is called on every process when using DDP.
- Parameters:
- stage : str
Either “fit”, “validate”, “test”, or “predict”.
- test_step(batch: Data, batch_idx: int) → None
Perform a single test step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- batch_idx : int
The index of the current batch.
- training_step(batch: Data, batch_idx: int) → Tensor
Perform a single training step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- batch_idx : int
The index of the current batch.
- Returns:
- torch.Tensor
A tensor of losses between model predictions and targets.
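With Lightning's automatic optimization, the value returned from training_step is what gets backpropagated. A minimal sketch, assuming a hypothetical model_step callable that returns a dict with a "loss" entry and a hypothetical "train/loss" log key; a float stands in for the loss tensor so the example runs without PyTorch.

```python
from typing import Any, Callable


def training_step(
    batch: Any,
    batch_idx: int,
    model_step: Callable[[Any], dict[str, Any]],
    log: Callable[[str, Any], None],
) -> Any:
    """Hypothetical sketch: run one model step and hand the loss to the trainer."""
    model_out = model_step(batch)
    log("train/loss", model_out["loss"])  # hypothetical log key
    # In Lightning, the returned loss tensor is what automatic optimization
    # calls backward() on; here a plain float stands in for it.
    return model_out["loss"]


logged: dict[str, Any] = {}
loss = training_step({"x": 1}, 0, lambda b: {"loss": 0.5}, logged.__setitem__)
print(loss)  # 0.5
```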