# Source code for torch_sla.partition

"""Partitioning algorithms for distributed sparse matrices.

This module hosts the data-layout side of the distributed stack:
the :class:`Partition` struct (one rank's view of which nodes it
owns + which neighbours it talks to) and the partitioner family that
builds them from a global graph -- METIS, simple 1-D slicing, RCB,
slicing-by-longest-axis, and a vectorised Hilbert space-filling
curve.

The Krylov-solver side lives in :mod:`torch_sla.distributed_solve`;
the :class:`DSparseTensor` itself lives in :mod:`torch_sla.distributed`.

External users typically just call :meth:`SparseTensor.partition_for_rank`
or :meth:`DSparseTensor.partition` and never import from this module
directly.
"""
from __future__ import annotations

import warnings
from dataclasses import dataclass
from typing import Dict, List, Tuple

import torch


@dataclass
class Partition:
    """Per-rank description of a partitioned global graph.

    Attributes
    ----------
    partition_id : int
        Identifier of the partition this struct describes.
    local_nodes : torch.Tensor
        Global indices of every node visible to this rank (owned + halo).
    owned_nodes : torch.Tensor
        Subset actually owned -- the rank's slice of the row partition.
    halo_nodes : torch.Tensor
        Ghost nodes pulled in from neighbours each halo exchange.
    neighbor_partitions : List[int]
        IDs of partitions this rank exchanges with.
    send_indices : Dict[int, torch.Tensor]
        For each neighbour, which of OUR owned nodes we send them.
    recv_indices : Dict[int, torch.Tensor]
        For each neighbour, which positions in the local layout the
        received data lands in.
    global_to_local : torch.Tensor
        Global -> local index map; -1 (or out-of-bounds) for nodes not
        in this partition.
    local_to_global : torch.Tensor
        Inverse map: local index -> global node id.
    """

    partition_id: int
    local_nodes: torch.Tensor
    owned_nodes: torch.Tensor
    halo_nodes: torch.Tensor
    neighbor_partitions: List[int]
    send_indices: Dict[int, torch.Tensor]
    recv_indices: Dict[int, torch.Tensor]
    global_to_local: torch.Tensor
    local_to_global: torch.Tensor

    def to(self, device: "torch.device | str") -> "Partition":
        """Build a new :class:`Partition` whose tensors live on ``device``.

        Only the device changes; ``Tensor.to(device)`` leaves the index
        dtypes untouched. The neighbour list is shallow-copied so the
        result does not share the list object with ``self``. Used by
        :meth:`DSparseTensor.to` so halo exchange tensors live alongside
        the local matrix data.
        """

        def _relocate(table: Dict[int, torch.Tensor]) -> Dict[int, torch.Tensor]:
            # Per-neighbour index tables: same keys, tensors moved.
            return {nbr: idx.to(device) for nbr, idx in table.items()}

        moved_local = self.local_nodes.to(device)
        moved_owned = self.owned_nodes.to(device)
        moved_halo = self.halo_nodes.to(device)
        moved_g2l = self.global_to_local.to(device)
        moved_l2g = self.local_to_global.to(device)

        return Partition(
            partition_id=self.partition_id,
            local_nodes=moved_local,
            owned_nodes=moved_owned,
            halo_nodes=moved_halo,
            neighbor_partitions=list(self.neighbor_partitions),
            send_indices=_relocate(self.send_indices),
            recv_indices=_relocate(self.recv_indices),
            global_to_local=moved_g2l,
            local_to_global=moved_l2g,
        )
def partition_graph_metis(
    row: torch.Tensor,
    col: torch.Tensor,
    num_nodes: int,
    num_parts: int,
) -> torch.Tensor:
    """Partition a graph using METIS (via ``pymetis``).

    The COO pattern ``(row, col)`` is turned into the undirected, simple
    graph METIS expects: self-loops are dropped, every edge is mirrored,
    and duplicate entries are merged. Falls back to
    :func:`partition_simple` (with a warning) if ``pymetis`` isn't
    importable.

    Parameters
    ----------
    row, col : torch.Tensor
        Edge endpoints (COO indices) of the global graph.
    num_nodes : int
        Number of nodes in the global graph.
    num_parts : int
        Number of partitions requested.

    Returns
    -------
    partition_ids : torch.Tensor
        Partition ID per node, shape ``(num_nodes,)``, dtype ``int64``.
    """
    # Keep the ``try`` down to the import itself so an ImportError raised
    # while actually running METIS is not misreported as "pymetis missing".
    try:
        import pymetis
    except ImportError:
        warnings.warn(
            "pymetis not available, using simple geometric partitioning",
            stacklevel=2,
        )
        return partition_simple(num_nodes, num_parts)

    src = row.detach().cpu().to(torch.int64)
    dst = col.detach().cpu().to(torch.int64)

    # Diagonal entries are self-loops, which METIS does not accept.
    off_diag = src != dst
    src, dst = src[off_diag], dst[off_diag]

    # Encode each directed edge as one integer key, add the mirrored edge,
    # and let ``unique`` both deduplicate and sort by (source, target).
    keys = torch.unique(
        torch.cat([src * num_nodes + dst, dst * num_nodes + src])
    )
    heads = torch.div(keys, num_nodes, rounding_mode="floor")
    tails = keys - heads * num_nodes

    # CSR layout: xadj[i]..xadj[i + 1] delimits node i's neighbours in adjncy.
    degree = torch.bincount(heads, minlength=num_nodes)
    xadj = [0] + torch.cumsum(degree, dim=0).tolist()
    adjncy = tails.tolist()

    _, membership = pymetis.part_graph(num_parts, xadj=xadj, adjncy=adjncy)
    return torch.tensor(membership, dtype=torch.int64)
def partition_simple(num_nodes: int, num_parts: int) -> torch.Tensor:
    """Contiguous 1-D partitioning of ``[0, num_nodes)``.

    Nodes are cut into consecutive chunks of ``ceil(num_nodes / num_parts)``
    ids; rank ``k`` owns the ``k``-th chunk. Fast, no quality guarantees.
    Because the chunk size is rounded up, trailing ranks may receive fewer
    nodes or none at all (e.g. 9 nodes over 6 parts leaves part 5 empty).

    Parameters
    ----------
    num_nodes : int
        Number of nodes in the global graph.
    num_parts : int
        Number of partitions; must be at least 1.

    Returns
    -------
    partition_ids : torch.Tensor
        Partition ID per node, shape ``(num_nodes,)``, dtype ``int64``.

    Raises
    ------
    ValueError
        If ``num_parts`` is smaller than 1.
    """
    if num_parts < 1:
        raise ValueError(f"num_parts must be >= 1, got {num_parts}")

    # Ceil division; floored at 1 so an empty graph never divides by zero.
    nodes_per_part = max(1, (num_nodes + num_parts - 1) // num_parts)
    idx = torch.arange(num_nodes, dtype=torch.int64)
    return torch.clamp(idx // nodes_per_part, max=num_parts - 1)
def partition_coordinates(
    coords: torch.Tensor,
    num_parts: int,
    method: str = "rcb",
) -> torch.Tensor:
    """Geometric partitioning by node coordinates.

    Three methods are supported:

    * ``"rcb"`` -- Recursive Coordinate Bisection along the longest axis.
      Standard CFD/FEM partitioner.
    * ``"slicing"`` -- Sort along the longest axis, slice into equal
      chunks. Fast but worse quality than RCB.
    * ``"hilbert"`` -- Sort by Hilbert space-filling-curve index, slice
      into equal chunks. ~10-100x faster than METIS for PDE/mesh graphs
      where geometric locality correlates with sparse adjacency. 2-D and
      3-D only.

    Any other ``method`` value is treated as ``"slicing"``.

    Parameters
    ----------
    coords : torch.Tensor
        Node coordinates, shape ``(num_nodes, dim)``. Moved to CPU
        internally, so tensors on any device are accepted.
    num_parts : int
        Number of partitions (power of 2 recommended for RCB); must be
        at least 1.
    method : str
        ``"rcb"`` / ``"slicing"`` / ``"hilbert"``.

    Returns
    -------
    partition_ids : torch.Tensor
        Partition ID per node, shape ``(num_nodes,)``, dtype ``int64``,
        on CPU.

    Raises
    ------
    ValueError
        If ``num_parts`` is smaller than 1.
    """
    if num_parts < 1:
        raise ValueError(f"num_parts must be >= 1, got {num_parts}")

    # The result and the index tensors built below live on CPU; bringing the
    # coordinates over once keeps every indexing op on a single device.
    coords = coords.detach().cpu()
    num_nodes = coords.size(0)
    partition_ids = torch.zeros(num_nodes, dtype=torch.int64)
    if num_nodes == 0:
        # max/min reductions below are undefined on an empty point set.
        return partition_ids

    if method == "rcb":
        _rcb_partition(
            coords, partition_ids, torch.arange(num_nodes), 0, num_parts)
        return partition_ids

    if method == "hilbert":
        sorted_idx = _hilbert_sort_indices(coords)
    else:
        # Slicing: order nodes along the axis with the largest extent.
        ranges = coords.max(0).values - coords.min(0).values
        axis = int(ranges.argmax())
        sorted_idx = coords[:, axis].argsort()

    # Rank r in the sorted order goes to chunk r // ceil(N / P); the clamp
    # keeps the id in range. One scatter replaces a per-node Python loop.
    nodes_per_part = (num_nodes + num_parts - 1) // num_parts
    ranks = torch.arange(num_nodes, dtype=torch.int64)
    chunk_of_rank = torch.clamp(
        torch.div(ranks, nodes_per_part, rounding_mode="floor"),
        max=num_parts - 1,
    )
    partition_ids[sorted_idx] = chunk_of_rank
    return partition_ids
def _hilbert_curve_indices(coords: torch.Tensor, order: int = 16) -> torch.Tensor: """Map each row of ``coords`` to its position along a 2-D / 3-D Hilbert space-filling curve. Vectorised Skilling 2003 ``AxesToTranspose`` / ``TransposeToHilbert`` bit-twiddling: the per-point loop is folded into a torch tensor axis so every bit op runs on all ``N`` points at once. The remaining loops are over ``d`` (=2/3) and ``order`` (=16), both tiny. ``order`` is the bit-depth per axis (default 16 → up to ~65k cells per dim before precision becomes the bottleneck). Returns the raw Hilbert index per row (shape ``(N,)``, dtype int64). """ if coords.dim() != 2: raise ValueError(f"coords must be 2-D, got shape {tuple(coords.shape)}") n, d = coords.shape if d not in (2, 3): raise ValueError( f"Hilbert partitioner supports 2-D / 3-D coords, got d={d}") if coords.dtype.is_floating_point: c_min = coords.min(0).values c_max = coords.max(0).values extent = (c_max - c_min).clamp_min(1e-30) scale = (1 << order) - 1 x = ((coords - c_min) / extent * scale).clamp_(0, scale).long() else: x = coords.long().clone() # Skilling 2003 "inverse undo" pass q = 1 << (order - 1) while q > 1: p = q - 1 for i in range(d): mask = (x[:, i] & q) != 0 t = (x[:, 0] ^ x[:, i]) & p x_zero_true = x[:, 0] ^ p x_zero_false = x[:, 0] ^ t x_i_false = x[:, i] ^ t x[:, 0] = torch.where(mask, x_zero_true, x_zero_false) x[:, i] = torch.where(mask, x[:, i], x_i_false) q >>= 1 # Gray-encode pass for i in range(1, d): x[:, i] ^= x[:, i - 1] # Untangle pass t = torch.zeros(n, dtype=torch.int64) q = 1 << (order - 1) while q > 1: mask = (x[:, d - 1] & q) != 0 t ^= torch.where(mask, torch.full_like(t, q - 1), torch.zeros_like(t)) q >>= 1 x ^= t.unsqueeze(1) # Bit-interleave to build the Hilbert index, MSB → LSB. 
h = torch.zeros(n, dtype=torch.int64) for bit in range(order - 1, -1, -1): for i in range(d): h = (h << 1) | ((x[:, i] >> bit) & 1) return h def _hilbert_sort_indices(coords: torch.Tensor, order: int = 16) -> torch.Tensor: """Return node indices sorted by Hilbert-curve position. Sorting by Hilbert index gives blocks with strong geometric locality -- exactly what sparse matvec wants for cheap halo exchange on PDE matrices.""" return _hilbert_curve_indices(coords, order=order).argsort() def _rcb_partition( coords: torch.Tensor, partition_ids: torch.Tensor, node_indices: torch.Tensor, part_offset: int, num_parts: int, ) -> None: """Recursive Coordinate Bisection helper. Mutates ``partition_ids`` in place.""" if num_parts == 1 or len(node_indices) == 0: partition_ids[node_indices] = part_offset return local_coords = coords[node_indices] ranges = local_coords.max(0).values - local_coords.min(0).values axis = ranges.argmax().item() axis_vals = local_coords[:, axis] median = axis_vals.median() left_mask = axis_vals <= median right_mask = ~left_mask left_nodes = node_indices[left_mask] right_nodes = node_indices[right_mask] left_parts = num_parts // 2 right_parts = num_parts - left_parts _rcb_partition(coords, partition_ids, left_nodes, part_offset, left_parts) _rcb_partition(coords, partition_ids, right_nodes, part_offset + left_parts, right_parts) def resolve_partition_ids( row: torch.Tensor, col: torch.Tensor, num_nodes: int, num_parts: int, method: str, coords: torch.Tensor = None, ) -> torch.Tensor: """Dispatch to the right partitioner by string name and return the per-node partition-id tensor of shape ``(num_nodes,)``. ``method`` may be ``"auto"`` (uses ``"rcb"`` if ``coords`` is given, else ``"simple"`` -- deterministic for distributed setups) / ``"simple"`` / ``"metis"`` / ``"rcb"`` / ``"slicing"`` / ``"hilbert"``. Geometric methods need ``coords``. 
""" if method == "auto": method = "rcb" if coords is not None else "simple" if method == "simple": return partition_simple(num_nodes, num_parts) if method == "metis": return partition_graph_metis(row, col, num_nodes, num_parts) if method in ("rcb", "slicing", "hilbert"): if coords is None: raise ValueError(f"partition method '{method}' requires coords") return partition_coordinates(coords, num_parts, method=method) raise ValueError( f"Unknown partition method {method!r}; expected one of " "auto / simple / metis / rcb / slicing / hilbert.") def build_partition( row: torch.Tensor, col: torch.Tensor, num_nodes: int, partition_ids: torch.Tensor, my_partition: int, ) -> Partition: """Build a full :class:`Partition` struct for one rank from a partition-id assignment over the global graph. Discovers owned / halo nodes, computes the global↔local index map, and lays out send / recv buffers for halo exchange. Used by :meth:`DSparseTensor.partition` and :meth:`DSparseTensor.from_global_distributed` to scatter a global graph across the mesh without ever materialising the dense layout. """ owned_mask = partition_ids == my_partition owned_nodes = owned_mask.nonzero().squeeze(-1) halo_nodes, send_map = find_halo_nodes(row, col, partition_ids, my_partition) local_nodes = torch.cat([owned_nodes, halo_nodes]) num_local = len(local_nodes) # Global → local index map (vectorised). ``-1`` marks "not on this rank". global_to_local = torch.full((num_nodes,), -1, dtype=torch.int64) global_to_local[local_nodes] = torch.arange(num_local, dtype=torch.int64) # Map each neighbour's send back to local recv positions. 
halo_offset = len(owned_nodes) halo_to_local = torch.full((num_nodes,), -1, dtype=torch.int64) halo_to_local[halo_nodes] = torch.arange(len(halo_nodes), dtype=torch.int64) + halo_offset recv_indices: Dict[int, torch.Tensor] = {} for neighbor_id in send_map.keys(): neighbor_owned = (partition_ids == neighbor_id).nonzero().squeeze(-1) local_idx = halo_to_local[neighbor_owned] recv_indices[neighbor_id] = local_idx[local_idx >= 0] # send_map stores global node ids; halo_exchange wants local indices. send_indices_local: Dict[int, torch.Tensor] = {} for neighbor_id, global_nodes in send_map.items(): send_indices_local[neighbor_id] = global_to_local[global_nodes] return Partition( partition_id=my_partition, local_nodes=local_nodes, owned_nodes=owned_nodes, halo_nodes=halo_nodes, neighbor_partitions=list(send_map.keys()), send_indices=send_indices_local, recv_indices=recv_indices, global_to_local=global_to_local, local_to_global=local_nodes.clone(), ) def find_halo_nodes( row: torch.Tensor, col: torch.Tensor, partition_ids: torch.Tensor, partition_id: int, ) -> Tuple[torch.Tensor, Dict[int, torch.Tensor]]: """Find halo (ghost) nodes for a partition. A halo node is owned by some other partition but referenced by one of our owned rows -- we need its value to do local SpMV. Returns ------- halo_nodes : torch.Tensor Global indices of nodes this rank pulls in each exchange. send_map : Dict[int, torch.Tensor] For each neighbour partition, the global indices of our owned nodes that they need from us. 
""" owned_mask = partition_ids == partition_id row_cpu = row.cpu() col_cpu = col.cpu() row_owned = owned_mask[row_cpu] col_owned = owned_mask[col_cpu] # Case 1: row owned, col not owned -> col is halo mask1 = row_owned & ~col_owned halo_from_col = col_cpu[mask1] send_to_neighbor_col = row_cpu[mask1] neighbor_ids_col = partition_ids[halo_from_col] # Case 2: col owned, row not owned -> row is halo mask2 = col_owned & ~row_owned halo_from_row = row_cpu[mask2] send_to_neighbor_row = col_cpu[mask2] neighbor_ids_row = partition_ids[halo_from_row] all_halo = torch.cat([halo_from_col, halo_from_row]) halo_nodes = torch.unique(all_halo, sorted=True) all_neighbors = torch.cat([neighbor_ids_col, neighbor_ids_row]) all_send_nodes = torch.cat([send_to_neighbor_col, send_to_neighbor_row]) send_map: Dict[int, torch.Tensor] = {} unique_neighbors = torch.unique(all_neighbors) for neighbor_id in unique_neighbors.tolist(): mask = all_neighbors == neighbor_id send_map[neighbor_id] = torch.unique(all_send_nodes[mask], sorted=True) return halo_nodes, send_map