Source code for topomodelx.utils.scatter
"""Utils for scatter.
Adaptation of torch_scatter/scatter.py from:
https://github.com/rusty1s/pytorch_scatter/blob/master/torch_scatter/scatter.py
"""
import torch
[docs]
def broadcast(src: torch.Tensor, other: torch.Tensor, dim: int) -> torch.Tensor:
    """Expand ``src`` so that it has the shape of ``other``.

    A 1-D ``src`` is laid out along axis ``dim`` of ``other``: ``dim``
    singleton axes are inserted in front of it and trailing singleton axes
    are appended until the rank matches ``other``. A ``src`` of any other
    rank only receives the trailing singleton axes. The result is then
    expanded to ``other.size()``.

    Parameters
    ----------
    src : torch.Tensor
        Tensor to broadcast.
    other : torch.Tensor
        Tensor whose shape is the broadcast target.
    dim : int
        Axis of ``other`` along which a 1-D ``src`` runs; negative values
        count from the end of ``other``'s dimensions.

    Returns
    -------
    torch.Tensor
        ``src`` expanded (as a view) to shape ``other.size()``.
    """
    axis = dim + other.dim() if dim < 0 else dim
    expanded = src
    if expanded.dim() == 1 and axis > 0:
        # Prepend `axis` singleton dimensions so the values sit on `axis`.
        expanded = expanded[(None,) * axis]
    # Append singleton dimensions until the rank matches `other`.
    missing = other.dim() - expanded.dim()
    if missing > 0:
        expanded = expanded[(Ellipsis,) + (None,) * missing]
    return expanded.expand(other.size())
[docs]
def scatter_sum(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: torch.Tensor | None = None,
    dim_size: int | None = None,
) -> torch.Tensor:
    """Sum the entries of ``src`` that share an index along ``dim``.

    Parameters
    ----------
    src : torch.Tensor
        Values to accumulate.
    index : torch.Tensor
        Target positions; broadcast to ``src``'s shape via ``broadcast``.
    dim : int, default=-1
        Axis along which to scatter.
    out : torch.Tensor, optional
        Destination tensor. When given it is updated in place with
        ``scatter_add_`` and ``dim_size`` is ignored.
    dim_size : int, optional
        Length of the output along ``dim`` when ``out`` is not given.
        Defaults to ``index.max() + 1``, or 0 for an empty ``index``.

    Returns
    -------
    torch.Tensor
        The accumulated tensor (``out`` itself when it was provided).
    """
    index = broadcast(index, src, dim)
    if out is None:
        if dim_size is None:
            # Infer the output length from the largest target index.
            dim_size = 0 if index.numel() == 0 else int(index.max()) + 1
        shape = list(src.size())
        shape[dim] = dim_size
        out = torch.zeros(shape, dtype=src.dtype, device=src.device)
    return out.scatter_add_(dim, index, src)
[docs]
def scatter_add(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: torch.Tensor | None = None,
    dim_size: int | None = None,
) -> torch.Tensor:
    """Alias of ``scatter_sum``: add values of ``src`` into ``out`` at ``index``.

    All arguments are forwarded unchanged to ``scatter_sum``; see that
    function for their meaning.
    """
    return scatter_sum(
        src=src,
        index=index,
        dim=dim,
        out=out,
        dim_size=dim_size,
    )
[docs]
def scatter_mean(
    src: torch.Tensor,
    index: torch.Tensor,
    dim: int = -1,
    out: torch.Tensor | None = None,
    dim_size: int | None = None,
) -> torch.Tensor:
    """Average the entries of ``src`` that share an index along ``dim``.

    The values are first summed with ``scatter_sum`` and then divided by the
    number of contributions per output slot. Slots that received no
    contribution are divided by 1. Floating-point results use true division;
    other dtypes use floor division.

    NOTE(review): when ``out`` is supplied it is modified in place, and any
    values it already held are included in the sum before the division.

    Parameters
    ----------
    src : torch.Tensor
        Values to average.
    index : torch.Tensor
        Target positions along ``dim``.
    dim : int, default=-1
        Axis along which to scatter.
    out : torch.Tensor, optional
        Destination tensor, updated in place when given.
    dim_size : int, optional
        Length of the output along ``dim`` when ``out`` is not given.

    Returns
    -------
    torch.Tensor
        The averaged tensor.
    """
    total = scatter_sum(src, index, dim, out, dim_size)
    n_slots = total.size(dim)

    # Axis of `index` to count along: resolve a negative `dim` against
    # `src`, then clamp to the last axis when `index` has lower rank.
    count_dim = dim + src.dim() if dim < 0 else dim
    count_dim = min(count_dim, index.dim() - 1)

    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
    count = scatter_sum(ones, index, count_dim, None, n_slots)
    # Empty slots would otherwise divide by zero.
    count[count < 1] = 1
    count = broadcast(count, total, dim)

    if not total.is_floating_point():
        total.div_(count, rounding_mode="floor")
    else:
        total.true_divide_(count)
    return total
# Reduction names accepted by `scatter`; "add" is an alias of "sum".
SCATTER_DICT = {
    "sum": scatter_sum,
    "mean": scatter_mean,
    "add": scatter_sum,
}
[docs]
def scatter(scatter: str):
    """Return the scatter function registered under the given name.

    Parameters
    ----------
    scatter : str
        Reduction name; must be a key of ``SCATTER_DICT``.

    Returns
    -------
    callable
        The scatter function stored in ``SCATTER_DICT`` under ``scatter``.

    Raises
    ------
    ValueError
        If ``scatter`` is not a registered name. This includes unhashable
        values, which would otherwise surface as an opaque ``TypeError``.
    """
    try:
        return SCATTER_DICT[scatter]
    except (KeyError, TypeError):
        # The old message ("must be string") was misleading for unknown
        # string names; list the valid choices and echo the bad value.
        raise ValueError(
            f"scatter must be one of {list(SCATTER_DICT)}, got {scatter!r}"
        ) from None