"""Split utilities."""
import os
import numpy as np
import torch
from sklearn.model_selection import StratifiedKFold
from topobenchmarkx.dataloader import DataloadDataset
# Generate splits in different fashions
def k_fold_split(labels, parameters):
    """Return train and valid indices as in K-Fold Cross-Validation.

    If the split already exists it loads it automatically, otherwise it
    creates the split file for the subsequent runs.

    Parameters
    ----------
    labels : torch.Tensor
        Label tensor.
    parameters : DictConfig
        Configuration parameters.

    Returns
    -------
    dict
        Dictionary containing the train, validation and test indices, with keys "train", "valid", and "test".
    """
    data_dir = parameters.data_split_dir
    k = parameters.k
    fold = parameters.data_seed
    assert fold < k, "data_seed needs to be less than k"

    # Fixed seeds so every run regenerates identical split files.
    torch.manual_seed(0)
    np.random.seed(0)

    split_dir = os.path.join(data_dir, f"{k}-fold")
    if not os.path.isdir(split_dir):
        os.makedirs(split_dir)

    split_path = os.path.join(split_dir, f"{fold}.npz")
    if not os.path.isfile(split_path):
        n = labels.shape[0]
        x_idx = np.random.permutation(np.arange(n))
        perm_labels = labels[x_idx]

        skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)
        for fold_n, (train_pos, valid_pos) in enumerate(
            skf.split(x_idx, perm_labels)
        ):
            # `skf.split` yields *positions* within the permuted ordering.
            # Map them back through `x_idx` so the stored indices refer to
            # the original samples (bug fix: previously the positional
            # indices were saved directly, which silently broke the
            # stratification with respect to the true labels).
            split_idx = {
                "train": x_idx[train_pos],
                "valid": x_idx[valid_pos],
                # In k-fold mode the held-out fold doubles as the test set.
                "test": x_idx[valid_pos],
            }
            # Check that all nodes/graphs have been assigned to some split.
            assert np.all(
                np.sort(
                    np.array(
                        split_idx["train"].tolist()
                        + split_idx["valid"].tolist()
                    )
                )
                == np.sort(np.arange(len(labels)))
            ), "Not every sample has been loaded."

            np.savez(os.path.join(split_dir, f"{fold_n}.npz"), **split_idx)

    split_idx = np.load(split_path)
    # Check that all nodes/graphs have been assigned to some split.
    assert (
        np.unique(
            np.array(
                split_idx["train"].tolist()
                + split_idx["valid"].tolist()
                + split_idx["test"].tolist()
            )
        ).shape[0]
        == labels.shape[0]
    ), "Not all nodes within splits"
    return split_idx
def random_splitting(labels, parameters, global_data_seed=42):
    r"""Randomly splits label into train/valid/test splits.

    Adapted from https://github.com/CUAI/Non-Homophily-Benchmarks.

    Parameters
    ----------
    labels : torch.Tensor
        Label tensor.
    parameters : DictConfig
        Configuration parameter.
    global_data_seed : int
        Seed for the random number generator.

    Returns
    -------
    dict:
        Dictionary containing the train, validation and test indices with keys "train", "valid", and "test".
    """
    fold = parameters["data_seed"]
    data_dir = parameters["data_split_dir"]
    train_prop = parameters["train_prop"]
    # Remaining probability mass is shared equally by valid and test.
    valid_prop = (1 - train_prop) / 2

    split_dir = os.path.join(
        data_dir, f"train_prop={train_prop}_global_seed={global_data_seed}"
    )
    split_path = os.path.join(split_dir, f"{fold}.npz")

    # (Re)generate the splits whenever the requested fold file is missing.
    # Robustness fix: the original generated splits only when the directory
    # did not exist, so a directory left behind by an interrupted run made
    # the np.load below crash. Regeneration is deterministic (seeded), so
    # existing files are rewritten with identical content.
    if not os.path.isfile(split_path):
        os.makedirs(split_dir, exist_ok=True)
        # Set initial seed
        torch.manual_seed(global_data_seed)
        np.random.seed(global_data_seed)
        n = labels.shape[0]
        train_num = int(n * train_prop)
        valid_num = int(n * valid_prop)
        # Generate 10 splits; `data_seed` must therefore be < 10.
        for fold_n in range(10):
            # Permute indices
            perm = torch.as_tensor(np.random.permutation(n))
            split_idx = {
                "train": perm[:train_num],
                "valid": perm[train_num : train_num + valid_num],
                "test": perm[train_num + valid_num :],
            }
            # Save generated split
            np.savez(os.path.join(split_dir, f"{fold_n}.npz"), **split_idx)

    # Load the requested split
    split_idx = np.load(split_path)
    # Check that all nodes/graph have been assigned to some split
    assert (
        np.unique(
            np.array(
                split_idx["train"].tolist()
                + split_idx["valid"].tolist()
                + split_idx["test"].tolist()
            )
        ).shape[0]
        == labels.shape[0]
    ), "Not all nodes within splits"
    return split_idx
def assing_train_val_test_mask_to_graphs(dataset, split_idx):
    r"""Split the graph dataset into train, validation, and test datasets.

    Parameters
    ----------
    dataset : torch_geometric.data.Dataset
        Considered dataset.
    split_idx : dict
        Dictionary containing the train, validation, and test indices.

    Returns
    -------
    list:
        List containing the train, validation, and test datasets.
    """
    data_train_lst, data_val_lst, data_test_lst = [], [], []

    # Convert index arrays to sets once: `i in numpy_array` is a linear
    # scan, which made the original loop O(n^2) over the dataset.
    train_ids = {int(i) for i in split_idx["train"]}
    valid_ids = {int(i) for i in split_idx["valid"]}
    test_ids = {int(i) for i in split_idx["test"]}

    # Go over each graph and assign the correct split mask. A graph may
    # appear in several splits (e.g. k-fold reuses valid as test); each
    # matching branch appends it and overwrites the masks, matching the
    # original behavior.
    for i in range(len(dataset)):
        graph = dataset[i]
        assigned = False
        if i in train_ids:
            graph.train_mask = torch.Tensor([1]).long()
            graph.val_mask = torch.Tensor([0]).long()
            graph.test_mask = torch.Tensor([0]).long()
            data_train_lst.append(graph)
            assigned = True
        if i in valid_ids:
            graph.train_mask = torch.Tensor([0]).long()
            graph.val_mask = torch.Tensor([1]).long()
            graph.test_mask = torch.Tensor([0]).long()
            data_val_lst.append(graph)
            assigned = True
        if i in test_ids:
            graph.train_mask = torch.Tensor([0]).long()
            graph.val_mask = torch.Tensor([0]).long()
            graph.test_mask = torch.Tensor([1]).long()
            data_test_lst.append(graph)
            assigned = True
        if not assigned:
            raise ValueError("Graph not in any split")

    return (
        DataloadDataset(data_train_lst),
        DataloadDataset(data_val_lst),
        DataloadDataset(data_test_lst),
    )
def load_transductive_splits(dataset, parameters):
    r"""Load the graph dataset with the specified split.

    Parameters
    ----------
    dataset : torch_geometric.data.Dataset
        Graph dataset.
    parameters : DictConfig
        Configuration parameters.

    Returns
    -------
    list:
        List containing the train, validation, and test splits.
    """
    # Transductive setting: a single graph carries all the splits as masks.
    assert (
        len(dataset) == 1
    ), "Dataset should have only one graph in a transductive setting."
    data = dataset.data_list[0]

    # Labels must come out as a flat 1-D array.
    labels = data.y.numpy()
    assert labels.ndim == 1, "Labels should be one dimensional array"

    split_type = parameters.split_type
    if split_type == "random":
        splits = random_splitting(labels, parameters)
    elif split_type == "k-fold":
        splits = k_fold_split(labels, parameters)
    else:
        raise NotImplementedError(
            f"split_type {parameters.split_type} not valid. Choose either 'random' or 'k-fold'"
        )

    # Attach the index tensors as masks on the single graph.
    for split_name, mask_name in (
        ("train", "train_mask"),
        ("valid", "val_mask"),
        ("test", "test_mask"),
    ):
        setattr(data, mask_name, torch.from_numpy(splits[split_name]))

    if parameters.get("standardize", False):
        # Standardize features and targets using statistics computed on the
        # training portion only, so no information leaks from valid/test.
        tr = data.train_mask
        x_train = data.x[tr]
        data.x = (data.x - x_train.mean(0)) / x_train.std(0)
        y_train = data.y[tr]
        data.y = (data.y - y_train.mean(0)) / y_train.std(0)

    return DataloadDataset([data]), None, None
def load_inductive_splits(dataset, parameters):
    r"""Load multiple-graph datasets with the specified split.

    Parameters
    ----------
    dataset : torch_geometric.data.Dataset
        Graph dataset.
    parameters : DictConfig
        Configuration parameters.

    Returns
    -------
    list:
        List containing the train, validation, and test splits.
    """
    # Inductive setting: many graphs, one label per graph.
    assert (
        len(dataset) > 1
    ), "Datasets should have more than one graph in an inductive setting."

    labels = np.array(
        [graph.y.squeeze(0).numpy() for graph in dataset.data_list]
    )

    split_type = parameters.split_type
    if split_type == "random":
        split_idx = random_splitting(labels, parameters)
    elif split_type == "k-fold":
        split_idx = k_fold_split(labels, parameters)
    elif split_type == "fixed" and hasattr(dataset, "split_idx"):
        # Dataset ships with its own canonical split.
        split_idx = dataset.split_idx
    else:
        raise NotImplementedError(
            f"split_type {parameters.split_type} not valid. Choose either 'random', 'k-fold' or 'fixed'.\
If 'fixed' is chosen, the dataset should have the attribute split_idx"
        )

    # Distribute the graphs into per-split datasets with masks attached.
    return assing_train_val_test_mask_to_graphs(dataset, split_idx)
def load_coauthorship_hypergraph_splits(data, parameters, train_prop=0.5):
    r"""Load the split generated by rand_train_test_idx function.

    Parameters
    ----------
    data : torch_geometric.data.Data
        Graph dataset.
    parameters : DictConfig
        Configuration parameters.
    train_prop : float
        Proportion of training data.

    Returns
    -------
    torch_geometric.data.Data:
        Graph dataset with the specified split.
    """
    # Pre-generated splits live under a per-train_prop directory, one file
    # per data seed.
    split_root = os.path.join(
        parameters["data_split_dir"], f"train_prop={train_prop}"
    )
    load_path = f"{split_root}/split_{parameters['data_seed']}.npz"
    splits = np.load(load_path, allow_pickle=True)

    # Attach the loaded index tensors as masks.
    for split_name, mask_name in (
        ("train", "train_mask"),
        ("valid", "val_mask"),
        ("test", "test_mask"),
    ):
        setattr(data, mask_name, torch.from_numpy(splits[split_name]))

    # Every node must appear in exactly one of the three splits.
    combined = torch.concat([data.train_mask, data.val_mask, data.test_mask])
    assert (
        torch.unique(combined).shape[0] == data.num_nodes
    ), "Not all nodes within splits"

    return DataloadDataset([data]), None, None