This module implements custom Python classes that represent models built on PyTorch Lightning within TopoBenchmarkX.
This module defines the TBXModel class.
- class topobenchmarkx.model.model.TBXModel(backbone: Module, backbone_wrapper: Module, readout: Module, loss: Module, feature_encoder: Module | None = None, evaluator: Any = None, optimizer: Any = None, **kwargs)
A LightningModule to define a network.
- Parameters:
- backbone : torch.nn.Module
The backbone model to train.
- backbone_wrapper : torch.nn.Module
The backbone wrapper class.
- readout : torch.nn.Module
The readout class.
- loss : torch.nn.Module
The loss class.
- feature_encoder : torch.nn.Module, optional
The feature encoder (default: None).
- evaluator : Any, optional
The evaluator class (default: None).
- optimizer : Any, optional
The optimizer class (default: None).
- **kwargs : Any
Additional keyword arguments.
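The constructor receives each stage of the pipeline as a separate component. The plain-Python stand-in below (no PyTorch required) sketches one plausible way they could be wired together. The class name `ToyTBXModel`, the assumption that the wrapper is applied around the backbone, and the identity fallback for a missing feature encoder are all illustrative assumptions, not the library's confirmed implementation.

```python
from typing import Any, Callable, Optional

Component = Callable[[Any], Any]


class ToyTBXModel:
    """Hypothetical stand-in that only stores and wires the components."""

    def __init__(
        self,
        backbone: Component,
        backbone_wrapper: Callable[[Component], Component],
        readout: Component,
        loss: Callable[[Any, Any], Any],
        feature_encoder: Optional[Component] = None,
        evaluator: Any = None,
        optimizer: Any = None,
        **kwargs: Any,
    ) -> None:
        # Assumption: the wrapper receives the raw backbone and returns a callable.
        self.backbone = backbone_wrapper(backbone)
        self.readout = readout
        self.loss = loss
        # Assumption: a missing encoder behaves like the identity function.
        self.feature_encoder = feature_encoder if feature_encoder is not None else (lambda batch: batch)
        self.evaluator = evaluator
        self.optimizer = optimizer
        self.extra_kwargs = kwargs  # remaining keyword arguments, kept for inspection


def double(x: int) -> int:
    return 2 * x


def add_one_wrapper(backbone: Component) -> Component:
    # Hypothetical wrapper that post-processes the backbone output.
    return lambda batch: backbone(batch) + 1


model = ToyTBXModel(double, add_one_wrapper, readout=lambda h: h, loss=lambda p, y: abs(p - y), task="demo")
print(model.backbone(3))         # 7
print(model.feature_encoder(5))  # 5
```

Keeping the wrapper separate from the backbone lets the same backbone be reused with different wrappers without modifying it.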
- configure_optimizers() → dict[str, Any]
Configure optimizers and learning-rate schedulers.
Choose which optimizers and learning-rate schedulers to use in your optimization. Normally you need only one, but GANs and similar models may need several.
- Returns:
- dict
A dict containing the configured optimizers and learning-rate schedulers to be used for training.
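PyTorch Lightning accepts a dictionary from configure_optimizers with an "optimizer" key and an optional "lr_scheduler" entry, which may itself be a dictionary holding "scheduler", "monitor", "interval", and "frequency". The helper below is a hypothetical sketch that assembles that shape from two factory callables; the function name `build_optimizer_config` and the default monitored metric "val/loss" are assumptions, and how TBXModel itself builds the dictionary is not shown here.

```python
from typing import Any, Callable, Iterable, Optional


def build_optimizer_config(
    params: Iterable[Any],
    optimizer_factory: Callable[[Iterable[Any]], Any],
    scheduler_factory: Optional[Callable[[Any], Any]] = None,
    monitor: str = "val/loss",  # hypothetical metric name
) -> dict[str, Any]:
    """Assemble the dictionary format Lightning accepts from configure_optimizers."""
    optimizer = optimizer_factory(params)
    config: dict[str, Any] = {"optimizer": optimizer}
    if scheduler_factory is not None:
        config["lr_scheduler"] = {
            "scheduler": scheduler_factory(optimizer),
            # Metric watched by schedulers such as ReduceLROnPlateau.
            "monitor": monitor,
            "interval": "epoch",  # step the scheduler once per epoch
            "frequency": 1,
        }
    return config


config = build_optimizer_config(["w", "b"], optimizer_factory=lambda params: {"params": list(params)})
print(config)  # {'optimizer': {'params': ['w', 'b']}}
```

With PyTorch installed, the factories could be `functools.partial(torch.optim.Adam, lr=1e-3)` and `torch.optim.lr_scheduler.ReduceLROnPlateau`, with `self.parameters()` passed as `params`.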
- forward(batch: Data) → dict
Perform a forward pass through the backbone model (self.backbone).
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the model output.
- log_metrics(mode=None)
Log metrics.
- Parameters:
- mode : str, optional
The mode of the model, either “train”, “val”, or “test” (default: None).
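A sketch of what logging metrics per mode can look like: an evaluator accumulates statistics over batches, and log_metrics computes the final values and logs each one under a mode-specific key. The `ToyEvaluator` class, its update/compute/reset methods, and the "<mode>/<metric>" key scheme are assumptions for illustration; the `log` callable stands in for LightningModule.log. Resetting is deliberately left to the epoch hooks, matching the hook descriptions in this section.

```python
from typing import Callable, Optional


class ToyEvaluator:
    """Hypothetical evaluator that accumulates accuracy over batches."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: list[int], labels: list[int]) -> None:
        self.correct += sum(p == y for p, y in zip(preds, labels))
        self.total += len(labels)

    def compute(self) -> dict[str, float]:
        return {"accuracy": self.correct / self.total if self.total else 0.0}


def log_metrics(evaluator: ToyEvaluator, log: Callable[[str, float], None], mode: Optional[str] = None) -> None:
    """Log every computed metric under a "<mode>/<metric>" key (naming scheme assumed)."""
    for name, value in evaluator.compute().items():
        key = f"{mode}/{name}" if mode else name
        log(key, value)


logged: dict[str, float] = {}
ev = ToyEvaluator()
ev.update([1, 0, 1, 1], [1, 1, 1, 0])
log_metrics(ev, logged.__setitem__, mode="train")
print(logged)  # {'train/accuracy': 0.5}
```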
- model_step(batch: Data) → dict
Perform a single model step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the model output and the loss.
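A model step chains the components passed to the constructor. The function below is a hedged sketch of one plausible order (encode features, run the backbone as forward would, apply the readout, then compute the loss into the output dictionary), using plain lists and dicts instead of tensors. The exact order in TBXModel, and where process_outputs and the evaluator update fit, are assumptions; both are omitted here for brevity.

```python
from typing import Any, Callable

Batch = dict[str, Any]


def model_step(
    batch: Batch,
    encoder: Callable[[Batch], Batch],
    backbone: Callable[[Batch], dict[str, Any]],
    readout: Callable[[dict[str, Any], Batch], dict[str, Any]],
    loss_fn: Callable[[list[float], list[float]], float],
) -> dict[str, Any]:
    """Hypothetical model step: encode, run the backbone, read out, compute the loss."""
    batch = encoder(batch)
    model_out = backbone(batch)  # stands in for self.forward(batch)
    model_out = readout(model_out, batch)
    model_out["loss"] = loss_fn(model_out["logits"], batch["y"])
    return model_out


def mse(preds: list[float], targets: list[float]) -> float:
    # Mean squared error over paired predictions and targets.
    return sum((p - t) ** 2 for p, t in zip(preds, targets)) / len(targets)


out = model_step(
    {"x": [1.0, 2.0], "y": [2.0, 6.0]},
    encoder=lambda b: {**b, "x": [2 * v for v in b["x"]]},
    backbone=lambda b: {"hidden": b["x"]},
    readout=lambda o, b: {**o, "logits": o["hidden"]},
    loss_fn=mse,
)
print(out["loss"])  # 2.0
```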
- on_test_epoch_end() → None
Lightning hook that is called when a test epoch ends.
This hook is used to log the test metrics.
- on_test_epoch_start() → None
Lightning hook that is called when a test epoch begins.
This hook is used to reset the test metrics.
- on_train_epoch_end() → None
Lightning hook that is called when a train epoch ends.
This hook is used to log the train metrics.
- on_train_epoch_start() → None
Lightning hook that is called when a train epoch begins.
This hook is used to reset the train metrics.
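- on_val_epoch_start() → None
Reset the validation metrics at the beginning of a validation epoch.
Lightning does not define a hook with this name (its validation-epoch hook is on_validation_epoch_start), so this method runs only when it is called explicitly.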
- on_validation_epoch_end() → None
Lightning hook that is called when a validation epoch ends.
This hook is used to log the validation metrics.
- on_validation_epoch_start() → None
Lightning hook that is called when a validation epoch begins.
According to the PyTorch Lightning documentation, this hook is called at the beginning of the validation epoch.
Because the validation loop runs within the training epoch, the train metrics must be logged here, before the evaluator is reset to start the validation loop.
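The reason for logging train metrics in on_validation_epoch_start is easiest to see with a single shared metric buffer. The toy class below simulates a simplified hook order, assuming validation runs once at the end of each training epoch. `HookTrace` and `run_epoch` are hypothetical names, and the list buffer stands in for the evaluator.

```python
class HookTrace:
    """Toy model sharing one metric accumulator across train and validation."""

    def __init__(self) -> None:
        self.buffer: list[int] = []  # stands in for a single shared evaluator
        self.logged: dict[str, list[int]] = {}

    def on_train_epoch_start(self) -> None:
        self.buffer = []  # reset train metrics

    def training_step(self, value: int) -> None:
        self.buffer.append(value)

    def on_validation_epoch_start(self) -> None:
        # Validation runs inside the train epoch, so log train metrics first...
        self.logged["train"] = list(self.buffer)
        # ...and only then reset the shared buffer for validation.
        self.buffer = []

    def validation_step(self, value: int) -> None:
        self.buffer.append(value)

    def on_validation_epoch_end(self) -> None:
        self.logged["val"] = list(self.buffer)


def run_epoch(model: HookTrace, train_batches: list[int], val_batches: list[int]) -> None:
    """Simplified hook order, assuming validation runs once at the end of the epoch."""
    model.on_train_epoch_start()
    for b in train_batches:
        model.training_step(b)
    model.on_validation_epoch_start()
    for b in val_batches:
        model.validation_step(b)
    model.on_validation_epoch_end()
    # A train-epoch-end hook would run only here, after validation has finished.


m = HookTrace()
run_epoch(m, [1, 2, 3], [9])
print(m.logged)  # {'train': [1, 2, 3], 'val': [9]}
```

If the train metrics were logged only at the end of the train epoch, the shared buffer would by then hold the validation values instead.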
- process_outputs(model_out: dict, batch: Data) → dict
Handle model outputs.
- Parameters:
- model_out : dict
Dictionary containing the model output.
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- Returns:
- dict
Dictionary containing the updated model output.
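One common post-processing step for node-level tasks is to keep only the outputs for nodes in the current split. Whether process_outputs does exactly this is an assumption; the sketch below uses plain lists and the PyTorch Geometric mask naming convention (train_mask, val_mask, test_mask), and the `split` argument is hypothetical.

```python
from typing import Any


def process_outputs(model_out: dict[str, Any], batch: dict[str, Any], split: str) -> dict[str, Any]:
    """Keep only logits and labels of nodes that belong to `split` (hypothetical helper)."""
    mask = batch[f"{split}_mask"]
    model_out = dict(model_out)  # avoid mutating the caller's dictionary
    model_out["logits"] = [v for v, keep in zip(model_out["logits"], mask) if keep]
    model_out["labels"] = [v for v, keep in zip(model_out["labels"], mask) if keep]
    return model_out


batch = {"train_mask": [True, False, True], "val_mask": [False, True, False]}
out = {"logits": [0.1, 0.7, 0.4], "labels": [0, 1, 1]}
print(process_outputs(out, batch, "val"))  # {'logits': [0.7], 'labels': [1]}
```

With tensors, each list comprehension would reduce to boolean indexing such as `logits[mask]`.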
- setup(stage: str) → None
Hook to call torch.compile.
Lightning hook that is called at the beginning of fit (train + validate), validate, test, or predict.
This is a good place to build models dynamically or to adjust them. This hook is called on every process when using DDP.
- Parameters:
- stage : str
Either “fit”, “validate”, “test”, or “predict”.
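A common pattern is to compile the network in setup only when compilation is enabled and the stage is "fit". The sketch below injects the compile function so it runs without PyTorch; `ToyModule` and the `use_compile` flag are hypothetical, and whether TBXModel gates compilation exactly this way is an assumption.

```python
from typing import Any, Callable


class ToyModule:
    """Hypothetical module that compiles its backbone in setup() when requested."""

    def __init__(
        self,
        backbone: Callable[..., Any],
        use_compile: bool = False,
        compile_fn: Callable[[Any], Any] = lambda module: module,
    ) -> None:
        self.backbone = backbone
        self.use_compile = use_compile  # hypothetical flag, e.g. read from a config
        self.compile_fn = compile_fn    # torch.compile in a real PyTorch 2.x model

    def setup(self, stage: str) -> None:
        # Compile only for the "fit" stage; other stages keep the eager backbone.
        if self.use_compile and stage == "fit":
            self.backbone = self.compile_fn(self.backbone)


module = ToyModule(abs, use_compile=True, compile_fn=lambda f: ("compiled", f))
module.setup("validate")
print(module.backbone is abs)  # True
```

In a real model, `compile_fn` would be `torch.compile`, which takes an `nn.Module` and returns an optimized module.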
- test_step(batch: Data, batch_idx: int) → None
Perform a single test step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- batch_idx : int
The index of the current batch.
- training_step(batch: Data, batch_idx: int) → Tensor
Perform a single training step on a batch of data.
- Parameters:
- batch : torch_geometric.data.Data
Batch object containing the batched data.
- batch_idx : int
The index of the current batch.
- Returns:
- torch.Tensor
A tensor containing the loss between model predictions and targets.
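Both step methods can reuse model_step; the difference is that training_step returns the loss, which Lightning backpropagates under automatic optimization, while test_step returns nothing and leaves metric logging to the epoch hooks. The sketch below is a plain-Python stand-in: `ToySteps`, its `log` method (standing in for LightningModule.log), and the "train/loss" and "test/loss" keys are hypothetical.

```python
from typing import Any, Callable


class ToySteps:
    """Hypothetical module showing how step methods can reuse model_step."""

    def __init__(self, model_step: Callable[[Any], dict[str, Any]]) -> None:
        self.model_step = model_step
        self.logged: list[tuple[str, Any]] = []

    def log(self, name: str, value: Any) -> None:
        # Stands in for LightningModule.log.
        self.logged.append((name, value))

    def training_step(self, batch: Any, batch_idx: int) -> Any:
        model_out = self.model_step(batch)
        self.log("train/loss", model_out["loss"])
        return model_out["loss"]  # Lightning calls backward on the returned loss

    def test_step(self, batch: Any, batch_idx: int) -> None:
        model_out = self.model_step(batch)
        self.log("test/loss", model_out["loss"])
        # No return value: metrics are logged by the epoch hooks.


steps = ToySteps(lambda batch: {"loss": sum(batch)})
print(steps.training_step([1, 2], batch_idx=0))  # 3
print(steps.test_step([4], batch_idx=0))         # None
```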