DataLoader#

The dataloader module implements custom dataloaders for training.

Dataset class compatible with TBXDataloader.

class topobenchmarkx.dataloader.dataload_dataset.DataloadDataset(data_lst)[source]#

Custom dataset to return all the values added to the dataset object.

Parameters:
data_lst : list[torch_geometric.data.Data]

List of torch_geometric.data.Data objects.

get(idx)[source]#

Get data object from data list.

Parameters:
idx : int

Index of the data object to get.

Returns:
tuple

Tuple containing a list of all the values for the data and the corresponding keys.

len()[source]#

Return the length of the dataset.

Returns:
int

Length of the dataset.
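A minimal usage sketch, assuming a handful of torch_geometric.data.Data objects are at hand (the tensor contents below are illustrative placeholders):

import torch
from torch_geometric.data import Data
from topobenchmarkx.dataloader.dataload_dataset import DataloadDataset

# Wrap three toy graphs in a DataloadDataset.
data_lst = [
    Data(
        x=torch.randn(4, 3),
        edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
        y=torch.tensor([0]),
    )
    for _ in range(3)
]
dataset = DataloadDataset(data_lst)

print(dataset.len())           # 3
values, keys = dataset.get(0)  # all stored values and their corresponding keys
print(keys)                    # e.g. ['x', 'edge_index', 'y']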

TBXDataloader class.

class topobenchmarkx.dataloader.dataloader.TBXDataloader(dataset_train: DataloadDataset, dataset_val: DataloadDataset = None, dataset_test: DataloadDataset = None, batch_size: int = 1, num_workers: int = 0, pin_memory: bool = False, **kwargs: Any)[source]#

This class takes care of returning the dataloaders for the training, validation, and test datasets.

It also handles the collate function. The class is designed to work with standard PyTorch dataloaders.

Parameters:
dataset_train : DataloadDataset

The training dataset.

dataset_val : DataloadDataset, optional

The validation dataset (default: None).

dataset_test : DataloadDataset, optional

The test dataset (default: None).

batch_size : int, optional

The batch size for the dataloader (default: 1).

num_workers : int, optional

The number of worker processes to use for data loading (default: 0).

pin_memory : bool, optional

If True, the data loader will copy tensors into pinned memory before returning them (default: False).

**kwargs : optional

Additional arguments.

References

Read the docs:

https://lightning.ai/docs/pytorch/latest/data/datamodule.html
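A minimal construction sketch, assuming the data splits have already been wrapped in DataloadDataset objects (the toy_split helper and the graph contents are purely illustrative):

import torch
from torch_geometric.data import Data
from topobenchmarkx.dataloader.dataload_dataset import DataloadDataset
from topobenchmarkx.dataloader.dataloader import TBXDataloader

def toy_split(n):
    # Hypothetical helper that builds n tiny graphs for illustration only.
    return DataloadDataset([
        Data(
            x=torch.randn(4, 3),
            edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),
            y=torch.tensor([0]),
        )
        for _ in range(n)
    ])

datamodule = TBXDataloader(
    dataset_train=toy_split(8),
    dataset_val=toy_split(2),
    dataset_test=toy_split(2),
    batch_size=4,
    num_workers=0,
    pin_memory=False,
)

for batch in datamodule.train_dataloader():
    print(batch)  # a torch_geometric.data.Batch assembled by the custom collate function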

state_dict() → dict[Any, Any][source]#

Called when saving a checkpoint. Implement to generate and save the datamodule state.

Returns:
dict

A dictionary containing the datamodule state that you want to save.

teardown(stage: str | None = None) → None[source]#

Lightning hook for cleaning up after trainer.fit(), trainer.validate(), trainer.test(), and trainer.predict().

Parameters:
stage : str, optional

The stage being torn down. Either “fit”, “validate”, “test”, or “predict” (default: None).

test_dataloader() → DataLoader[source]#

Create and return the test dataloader.

Returns:
torch.utils.data.DataLoader

The test dataloader.

train_dataloader() → DataLoader[source]#

Create and return the train dataloader.

Returns:
torch.utils.data.DataLoader

The train dataloader.

val_dataloader() → DataLoader[source]#

Create and return the validation dataloader.

Returns:
torch.utils.data.DataLoader

The validation dataloader.
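Since TBXDataloader follows the LightningDataModule interface referenced above (state_dict and teardown are Lightning hooks), it can presumably be passed straight to a Lightning Trainer. A sketch, assuming model is a compatible LightningModule and datamodule is the TBXDataloader built in the previous example:

import lightning as L  # or pytorch_lightning, depending on the installed package

trainer = L.Trainer(max_epochs=10)
trainer.fit(model, datamodule=datamodule)   # uses train_dataloader() and val_dataloader()
trainer.test(model, datamodule=datamodule)  # uses test_dataloader()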

Dataloader utilities.

class topobenchmarkx.dataloader.utils.DomainData(x: Tensor | None = None, edge_index: Tensor | None = None, edge_attr: Tensor | None = None, y: Tensor | int | float | None = None, pos: Tensor | None = None, time: Tensor | None = None, **kwargs)[source]#

Helper Data class that allows sparse matrices to work with PyG dataloaders even when their attribute names do not contain adj.

It overrides some methods of torch_geometric.data.Data.

is_valid(string)[source]#

Check if the string contains any of the valid names.

Parameters:
string : str

String to check.

Returns:
bool

Whether the string contains any of the valid names.
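A sketch of the intended use, assuming a sparse incidence matrix stored under an attribute name without "adj" in it (the name incidence_1 is illustrative; whether is_valid accepts it depends on the internal list of valid names):

import torch
from topobenchmarkx.dataloader.utils import DomainData

# A toy sparse incidence matrix (4 nodes x 2 edges).
incidence_1 = torch.sparse_coo_tensor(
    indices=torch.tensor([[0, 1, 1, 2], [0, 0, 1, 1]]),
    values=torch.ones(4),
    size=(4, 2),
)
data = DomainData(x=torch.randn(4, 3), incidence_1=incidence_1)
print(data.is_valid("incidence_1"))  # True only if the name matches a valid name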

topobenchmarkx.dataloader.utils.collate_fn(batch)[source]#

Override the torch_geometric.data.DataLoader collate function to use the DomainData class.

This ensures that the torch_geometric dataloaders work with sparse matrices that are not necessarily named adj. The function also generates the batch slices for the different cell dimensions.

Parameters:
batch : list

List of data objects (e.g., torch_geometric.data.Data).

Returns:
torch_geometric.data.Batch

A torch_geometric.data.Batch object.
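A sketch of wiring collate_fn into a plain torch DataLoader over a DataloadDataset, mirroring what TBXDataloader does internally (pairing DataloadDataset items with collate_fn in this way is an assumption based on the classes documented above):

import torch
from torch.utils.data import DataLoader
from torch_geometric.data import Data
from topobenchmarkx.dataloader.dataload_dataset import DataloadDataset
from topobenchmarkx.dataloader.utils import collate_fn

dataset = DataloadDataset([
    Data(
        x=torch.randn(3, 2),
        edge_index=torch.tensor([[0, 1], [1, 2]]),
        y=torch.tensor([0]),
    )
    for _ in range(4)
])
loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
batch = next(iter(loader))  # a torch_geometric.data.Batch with per-dimension batch slices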

topobenchmarkx.dataloader.utils.to_data_list(batch)[source]#

Convert a batch back into a list of data objects. This workaround is needed because torch_geometric does not work when using torch.sparse instead of torch_sparse.

Parameters:
batch : torch_geometric.data.Batch

The batch of data.

Returns:
list

List of data objects.
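The inverse direction, assuming batch was produced by collate_fn as in the previous sketch:

from topobenchmarkx.dataloader.utils import to_data_list

data_list = to_data_list(batch)  # one data object per element of the original batch
print(len(data_list))            # expected to equal the batch size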