DataLoader
The dataloader module implements custom dataloaders for training.
Dataset class compatible with TBXDataloader.
- class topobenchmarkx.dataloader.dataload_dataset.DataloadDataset(data_lst)
Custom dataset that returns all the values added to the dataset object.
- Parameters:
- data_lst : list[torch_geometric.data.Data]
List of torch_geometric.data.Data objects.
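A minimal pure-Python sketch of the idea follows. It assumes that each item is returned as a pair of its stored values and their attribute names, so that a collate function can rebuild the object later. The SimpleDataloadDataset class, and the use of plain dicts in place of torch_geometric.data.Data objects, are hypothetical stand-ins and not the library's implementation.

```python
from typing import Any


class SimpleDataloadDataset:
    """Hypothetical stand-in for DataloadDataset.

    Each element of data_lst is a dict playing the role of a
    torch_geometric.data.Data object (attribute name -> value).
    """

    def __init__(self, data_lst: list[dict[str, Any]]) -> None:
        self.data_lst = data_lst

    def __len__(self) -> int:
        return len(self.data_lst)

    def __getitem__(self, idx: int) -> tuple[list[Any], list[str]]:
        data = self.data_lst[idx]
        keys = list(data.keys())
        # Assumed output format: every stored value together with its
        # attribute name, so no attribute is dropped before collation.
        return [data[key] for key in keys], keys


dataset = SimpleDataloadDataset([
    {"x_0": [[1.0], [2.0]], "x_1": [[0.5]], "y": 0},
    {"x_0": [[3.0]], "x_1": [[0.1], [0.2]], "y": 1},
])
values, keys = dataset[1]
print(keys)  # ['x_0', 'x_1', 'y']
```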
TBXDataloader class.
- class topobenchmarkx.dataloader.dataloader.TBXDataloader(dataset_train: DataloadDataset, dataset_val: DataloadDataset = None, dataset_test: DataloadDataset = None, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, **kwargs: Any)
This class returns the dataloaders for the training, validation, and test datasets and handles the collate function.
It is designed to work with PyTorch dataloaders (torch.utils.data.DataLoader).
- Parameters:
- dataset_train : DataloadDataset
The training dataset.
- dataset_val : DataloadDataset, optional
The validation dataset (default: None).
- dataset_test : DataloadDataset, optional
The test dataset (default: None).
- batch_size : int, optional
The batch size for the dataloader (default: 1).
- num_workers : int, optional
The number of worker processes to use for data loading (default: 0).
- pin_memory : bool, optional
If True, the data loader copies tensors into pinned memory before returning them (default: False).
- **kwargs : optional
Additional keyword arguments.
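The dataloaders this class returns group samples into batches of batch_size and pass each batch through a collate function. The pure-Python sketch below models only that step, assuming no shuffling. iterate_batches and count_collate are hypothetical helpers; the real class relies on PyTorch's DataLoader, and num_workers and pin_memory have no equivalent here.

```python
from typing import Any, Callable, Iterator, Sequence


def iterate_batches(
    dataset: Sequence[Any],
    batch_size: int,
    collate_fn: Callable[[list[Any]], Any],
) -> Iterator[Any]:
    """Yield collated batches of consecutive samples (no shuffling).

    The last batch may be smaller than batch_size, matching the default
    drop_last=False behaviour of torch.utils.data.DataLoader.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    for start in range(0, len(dataset), batch_size):
        stop = min(start + batch_size, len(dataset))
        batch = [dataset[i] for i in range(start, stop)]
        yield collate_fn(batch)


def count_collate(batch: list[Any]) -> int:
    # Hypothetical collate function: report how many samples it received.
    return len(batch)


sizes = list(iterate_batches(list(range(5)), batch_size=2, collate_fn=count_collate))
print(sizes)  # [2, 2, 1]
```

Passing the collate function in as a parameter mirrors how a DataLoader accepts collate_fn, which is where a function like collate_fn from the utilities module would plug in.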
- state_dict() → dict[Any, Any]
Called when saving a checkpoint. Implement this hook to generate and save the datamodule state.
- Returns:
- dict
A dictionary containing the datamodule state that you want to save.
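In Lightning, state_dict() is paired with a load_state_dict() hook that restores what was saved. The sketch below shows that round trip on a hypothetical, framework-free datamodule. The epochs_seen counter is an illustrative assumption, because the source does not say what state TBXDataloader actually stores.

```python
from typing import Any


class SketchDataModule:
    """Hypothetical datamodule with one piece of resumable state."""

    def __init__(self, batch_size: int = 1) -> None:
        self.batch_size = batch_size
        self.epochs_seen = 0  # illustrative state worth checkpointing

    def state_dict(self) -> dict[Any, Any]:
        # Called when a checkpoint is saved: return plain, serialisable values.
        return {"batch_size": self.batch_size, "epochs_seen": self.epochs_seen}

    def load_state_dict(self, state_dict: dict[Any, Any]) -> None:
        # Called when a checkpoint is loaded: restore what state_dict() returned.
        self.batch_size = state_dict["batch_size"]
        self.epochs_seen = state_dict["epochs_seen"]


original = SketchDataModule(batch_size=32)
original.epochs_seen = 3
checkpoint = original.state_dict()

restored = SketchDataModule()
restored.load_state_dict(checkpoint)
print(restored.batch_size, restored.epochs_seen)  # 32 3
```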
- teardown(stage: str | None = None) → None
Lightning hook for cleaning up after trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict().
- Parameters:
- stage : str, optional
The stage being torn down. Either “fit”, “validate”, “test”, or “predict” (default: None).
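A sketch of stage-aware cleanup follows, assuming a datamodule that drops its references to the datasets it no longer needs. The class, its attribute names, and which stage releases what are hypothetical choices for illustration; the source does not describe what TBXDataloader releases.

```python
from typing import Optional

VALID_STAGES = {"fit", "validate", "test", "predict"}


class SketchDataModule:
    def __init__(self) -> None:
        self.train_data: Optional[list] = [1, 2, 3]
        self.val_data: Optional[list] = [4]
        self.test_data: Optional[list] = [5, 6]

    def teardown(self, stage: Optional[str] = None) -> None:
        if stage is not None and stage not in VALID_STAGES:
            raise ValueError(f"unknown stage: {stage!r}")
        # stage=None is treated here as "tear down everything".
        if stage in (None, "fit"):
            # fit uses both the training and the validation data.
            self.train_data = None
            self.val_data = None
        elif stage == "validate":
            self.val_data = None
        if stage in (None, "test"):
            self.test_data = None
        # "predict" releases nothing in this sketch.


dm = SketchDataModule()
dm.teardown("fit")
print(dm.train_data, dm.val_data, dm.test_data)  # None None [5, 6]
```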
- test_dataloader() → DataLoader
Create and return the test dataloader.
- Returns:
- torch.utils.data.DataLoader
The test dataloader.
Dataloader utilities.
- class topobenchmarkx.dataloader.utils.DomainData(x: Tensor | None = None, edge_index: Tensor | None = None, edge_attr: Tensor | None = None, y: Tensor | int | float | None = None, pos: Tensor | None = None, time: Tensor | None = None, **kwargs)
Helper Data class that lets PyG dataloaders batch sparse matrices whose names do not contain adj, not only those that do.
It overrides some methods of torch_geometric.data.Data.
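In PyG, Data.__cat_dim__ decides along which dimension each attribute is concatenated when a batch is built, and by default it returns (0, 1), meaning block-diagonal stacking, for a sparse tensor only when the attribute name contains adj. A plausible reading of the override above, stated here as an assumption, is that DomainData applies the (0, 1) rule to every sparse matrix, such as incidence matrices with other names. The pure-Python sketch below models that rule and the resulting block-diagonal stacking on COO triplets; cat_dim, block_diag_coo, and the incidence_1 key are hypothetical.

```python
from typing import List, Tuple, Union

# A sparse matrix in COO form: (n_rows, n_cols, [(row, col, value), ...]).
Coo = Tuple[int, int, List[Tuple[int, int, float]]]


def cat_dim(key: str, is_sparse: bool, adj_only: bool) -> Union[int, Tuple[int, int]]:
    """Concatenation dimension for an attribute.

    adj_only=True mimics PyG's default rule; adj_only=False mimics the
    assumed DomainData override that treats every sparse matrix alike.
    """
    if is_sparse and (not adj_only or "adj" in key):
        return (0, 1)  # stack block-diagonally
    if "index" in key or key == "face":
        return -1
    return 0


def block_diag_coo(mats: List[Coo]) -> Coo:
    """Concatenate COO matrices along dims (0, 1), i.e. block-diagonally."""
    rows = cols = 0
    entries: List[Tuple[int, int, float]] = []
    for n_rows, n_cols, triplets in mats:
        # Shift each matrix past the rows and columns of the earlier ones.
        entries.extend((r + rows, c + cols, v) for r, c, v in triplets)
        rows += n_rows
        cols += n_cols
    return rows, cols, entries


print(cat_dim("incidence_1", is_sparse=True, adj_only=True))   # 0
print(cat_dim("incidence_1", is_sparse=True, adj_only=False))  # (0, 1)

a: Coo = (2, 1, [(0, 0, 1.0), (1, 0, 1.0)])  # 2 nodes, 1 edge
b: Coo = (3, 2, [(0, 0, 1.0), (1, 0, 1.0), (1, 1, 1.0), (2, 1, 1.0)])
print(block_diag_coo([a, b]))
# (5, 3, [(0, 0, 1.0), (1, 0, 1.0), (2, 1, 1.0), (3, 1, 1.0), (3, 2, 1.0), (4, 2, 1.0)])
```

Under the default rule, a sparse incidence_1 would be concatenated along dimension 0 only, which needs matching column counts and leaves the second sample's column indices unshifted; block-diagonal stacking keeps each sample's cells separate.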
- topobenchmarkx.dataloader.utils.collate_fn(batch)
Overrides the torch_geometric.data.DataLoader collate function to use the DomainData class.
This ensures that torch_geometric dataloaders work with sparse matrices whose names do not contain adj. The function also generates the batch slices for the different cell dimensions.
- Parameters:
- batch : list
List of data objects (e.g., torch_geometric.data.Data).
- Returns:
- torch_geometric.data.Batch
A torch_geometric.data.Batch object.
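Generating batch slices for several cell dimensions amounts to building, for each dimension, a vector that maps every cell to the index of the sample it came from, in the same way PyG's batch vector does for nodes. The pure-Python sketch below assumes feature keys named x_0, x_1, ... and output keys named batch_0, batch_1, ...; these names and the batch_vectors helper are hypothetical, and the real function works on tensors inside a torch_geometric.data.Batch.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def batch_vectors(samples: Sequence[Dict[str, list]]) -> Dict[str, List[int]]:
    """Map each cell of each dimension to the index of its sample.

    Every sample stores per-dimension features under keys "x_<dim>",
    with one row per cell (an assumed naming convention).
    """
    out: Dict[str, List[int]] = defaultdict(list)
    for sample_idx, sample in enumerate(samples):
        for key, rows in sample.items():
            if not key.startswith("x_"):
                continue  # skip labels and other non-feature attributes
            dim = key.split("_", 1)[1]
            out[f"batch_{dim}"].extend([sample_idx] * len(rows))
    return dict(out)


samples = [
    {"x_0": [[1.0], [2.0], [3.0]], "x_1": [[0.1], [0.2]], "y": 0},
    {"x_0": [[4.0], [5.0]], "x_1": [[0.3]], "x_2": [[9.9]], "y": 1},
]
print(batch_vectors(samples))
# {'batch_0': [0, 0, 0, 1, 1], 'batch_1': [0, 0, 1], 'batch_2': [1]}
```

A sample that has no cells of some dimension simply contributes nothing to that dimension's vector, as the missing x_2 in the first sample shows.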