Source code for topomodelx.utils.scatter

"""Utils for scatter.

Adaptation of torch_scatter/scatter.py from:
https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py
"""

import torch


def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int) -> torch.Tensor:
    """Broadcast ``src`` so that it has the same shape as ``other``.

    A 1-D ``src`` first receives ``dim`` leading singleton axes so that its
    only axis lines up with axis ``dim`` of ``other``. Trailing singleton
    axes are then appended until the ranks agree, and the result is expanded
    to ``other.size()``.

    Parameters
    ----------
    src : torch.Tensor
        Tensor to broadcast.
    other : torch.Tensor
        Tensor whose shape is the broadcast target.
    dim : int
        Axis of ``other`` that a 1-D ``src`` is aligned with. Negative values
        count from the last axis of ``other``.

    Returns
    -------
    torch.Tensor
        ``src`` expanded to shape ``other.size()``.
    """
    target_rank = other.dim()
    axis = dim + target_rank if dim < 0 else dim

    if src.dim() == 1:
        # Prepend singleton axes so the single axis of ``src`` sits at ``axis``.
        for _ in range(axis):
            src = src.unsqueeze(0)

    # Append trailing singleton axes until ``src`` has the rank of ``other``.
    while src.dim() < target_rank:
        src = src.unsqueeze(-1)

    return src.expand(other.size())
def scatter_sum(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: torch.Tensor | None = None,
    dim_size: int | None = None,
) -> torch.Tensor:
    """Sum the values of ``src`` that share an index along ``dim``.

    Parameters
    ----------
    src : torch.Tensor
        Values to scatter.
    index : torch.Tensor
        Destination positions along ``dim``. It is first broadcast to the
        shape of ``src`` with :func:`broadcast`.
    dim : int, default=-1
        Axis along which values are scattered.
    out : torch.Tensor, optional
        Tensor to accumulate into. When given, it is updated in place and
        returned, and ``dim_size`` is ignored.
    dim_size : int, optional
        Length of the output along ``dim`` when ``out`` is None. Defaults to
        ``index.max() + 1``, or 0 when ``index`` is empty.

    Returns
    -------
    torch.Tensor
        ``out`` (or a freshly allocated zero tensor) with the sums added in.
    """
    expanded_index = broadcast(index, src, dim)

    if out is None:
        if dim_size is None:
            # Infer the output length from the largest destination index.
            if expanded_index.numel() > 0:
                dim_size = int(expanded_index.max()) + 1
            else:
                dim_size = 0
        out_shape = list(src.size())
        out_shape[dim] = dim_size
        out = torch.zeros(out_shape, dtype=src.dtype, device=src.device)

    return out.scatter_add_(dim, expanded_index, src)
def scatter_add(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: torch.Tensor | None = None,
    dim_size: int | None = None,
) -> torch.Tensor:
    """Sum the values of ``src`` that share an index along ``dim``.

    Alias of :func:`scatter_sum`. All arguments are forwarded unchanged, so
    see that function for their meaning.

    Returns
    -------
    torch.Tensor
        Whatever :func:`scatter_sum` returns for the same arguments.
    """
    return scatter_sum(src=src, index=index, dim=dim, out=out, dim_size=dim_size)
def scatter_mean(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: torch.Tensor | None = None,
    dim_size: int | None = None,
) -> torch.Tensor:
    """Average the values of ``src`` that share an index along ``dim``.

    Parameters
    ----------
    src : torch.Tensor
        Values to scatter.
    index : torch.Tensor
        Destination positions along ``dim``. It may have lower rank than
        ``src``; it is then scattered along its own last axis when counting.
    dim : int, default=-1
        Axis along which values are scattered.
    out : torch.Tensor, optional
        Tensor to accumulate into. It is modified in place. Its existing
        contents are included in the sum but not in the per-position count.
    dim_size : int, optional
        Length of the output along ``dim`` when ``out`` is None.

    Returns
    -------
    torch.Tensor
        Per-position means. Positions that receive no values are divided by 1
        instead of 0. Non-floating-point results use floor division.
    """
    summed = scatter_sum(src, index, dim, out, dim_size)
    n_bins = summed.size(dim)

    # Axis of ``index`` matching ``dim`` of ``src``, clamped to the last
    # axis of ``index`` when ``index`` has lower rank.
    count_dim = dim + src.dim() if dim < 0 else dim
    count_dim = min(count_dim, index.dim() - 1)

    counts = scatter_sum(
        torch.ones(index.size(), dtype=src.dtype, device=src.device),
        index,
        count_dim,
        None,
        n_bins,
    )
    # Avoid division by zero for positions that received no values.
    counts[counts < 1] = 1
    counts = broadcast(counts, summed, dim)

    if summed.is_floating_point():
        return summed.true_divide_(counts)
    return summed.div_(counts, rounding_mode="floor")
# Registry of supported reductions, keyed by the name accepted by `scatter`.
# "add" is an alias that maps to the same function as "sum".
SCATTER_DICT = {
    "sum": scatter_sum,
    "mean": scatter_mean,
    "add": scatter_sum,
}
def scatter(scatter: str):
    """Return the scatter function registered under a reduction name.

    Parameters
    ----------
    scatter : str
        Reduction name; must be a key of ``SCATTER_DICT``.

    Returns
    -------
    callable
        The scatter function stored in ``SCATTER_DICT`` under ``scatter``.

    Raises
    ------
    ValueError
        If ``scatter`` is not a key of ``SCATTER_DICT``.
    """
    try:
        return SCATTER_DICT[scatter]
    except KeyError:
        # The old message ("must be string") was misleading, because a valid
        # string such as "max" also triggers it. Name the accepted keys and
        # echo the rejected value instead.
        raise ValueError(
            f"scatter must be one of {list(SCATTER_DICT)}, got {scatter!r}"
        ) from None