# Source code for torch_sla.distributed.core

"""
Distributed sparse matrix for large-scale CFD / FEM / GNN computations.

:class:`DSparseTensor` mirrors :class:`torch.distributed.tensor.DTensor`:
each rank holds a local :class:`~torch_sla.sparse_tensor.SparseTensor`
chunk plus a :class:`~torch_sla.partition.Partition` map (owned rows +
halo) and the matvec stays entirely in the ``Shard(0)`` space via halo
exchange + local SpMV.

::

    from torch_sla import SparseTensor, DSparseTensor, solve, SolverConfig
    from torch.distributed.device_mesh import init_device_mesh

    A    = SparseTensor(val, row, col, shape)
    mesh = init_device_mesh("cpu", (world_size,))
    D    = DSparseTensor.partition(A, mesh, partition_method="metis")
    b_dt = D.scatter(b_global)

    with SolverConfig(method="cg", atol=1e-10, rtol=1e-10, maxiter=2000):
        x_dt = solve(D, b_dt)

The Krylov methods (CG / BiCGStab / GMRES / FGMRES / MINRES) and
preconditioners (Jacobi / block-Jacobi / SSOR / polynomial) live in
:mod:`torch_sla.distributed_solve`; partitioning lives in
:mod:`torch_sla.partition`.
"""

import os
import torch
from typing import Any, Tuple, List, Dict, Optional, Union, Literal
from dataclasses import dataclass
import warnings

from ..backends import (
    is_scipy_available,
    is_eigen_available,
    is_cupy_available,
    is_cudss_available,
    select_backend,
    select_method,
    BackendType,
    MethodType,
)

try:
    import torch.distributed as dist
    DIST_AVAILABLE = True
except ImportError:
    DIST_AVAILABLE = False

# DTensor support (PyTorch 2.0+). On torch >=2.2 these live under
# ``torch.distributed.tensor``; torch 2.0-2.1 still keeps them under the
# private ``_tensor`` namespace. Centralise the version check here so
# every runtime import below can use the same names.
try:
    from torch.distributed.tensor import DTensor
    from torch.distributed.tensor.placement_types import Shard, Replicate
    DTENSOR_AVAILABLE = True
    _dtensor_module = "torch.distributed.tensor"
except ImportError:
    try:
        from torch.distributed._tensor import DTensor
        from torch.distributed._tensor.placement_types import Shard, Replicate
        DTENSOR_AVAILABLE = True
        _dtensor_module = "torch.distributed._tensor"
    except ImportError:
        DTENSOR_AVAILABLE = False
        DTensor = None
        Shard = None
        Replicate = None
        _dtensor_module = None


def _is_dtensor(x) -> bool:
    """Return ``True`` when *x* is a :class:`DTensor` instance.

    Always ``False`` if no DTensor implementation could be imported
    (``DTENSOR_AVAILABLE`` is false or ``DTensor`` is ``None``).
    """
    return bool(DTENSOR_AVAILABLE) and DTensor is not None and isinstance(x, DTensor)


# Partition struct + partitioning algorithms (METIS / simple / RCB /
# slicing / Hilbert) + halo discovery live in :mod:`torch_sla.partition`
# now. Re-exported here so existing ``from torch_sla.distributed import
# Partition, partition_simple, ...`` call sites keep working.
from ..partition import (
    Partition,
    partition_graph_metis,
    partition_simple,
    partition_coordinates,
    _hilbert_curve_indices,
    _hilbert_sort_indices,
    _rcb_partition,
    find_halo_nodes,
    build_partition,
    resolve_partition_ids,
)


# ====================================================================== #
# DTensor-mirror placement vocabulary for sparse tensors.
#
# Adapted from ``torch.distributed.tensor.placement_types``:
#
# * :class:`Replicated`            --  every rank holds the full matrix
#                                      (analogous to ``Replicate()``).
# * :class:`VertexShard`           --  rows are split across ranks via an
#                                      irregular METIS / RCB / simple vertex
#                                      map (the sparse analog of ``Shard(0)``;
#                                      *not* the uniform DTensor shard).
# * :class:`VertexShardReplicated` --  the same vertex map with a
#                                      column-storage local layout (partial
#                                      matvec products are summed).
# * :class:`BatchShard`            --  a batch axis is split across ranks;
#                                      the sparsity pattern is replicated.
#
# A :class:`DSparseSpec` bundles a placement with the device mesh and
# global shape so the rest of the API can mirror DTensor's
# ``DTensor._spec``.
# ====================================================================== #
@dataclass(frozen=True)
class Replicated:
    """DSparseTensor placement: every rank holds the entire matrix.

    Has no fields, so all instances compare equal and hash alike.
    """


@dataclass(frozen=True)
class VertexShard:
    """DSparseTensor placement: METIS-style vertex partition with
    **row-storage** local layout.

    Each rank holds ``A[owned_vertices, local_to_global]`` -- the rows
    are restricted to its owned vertex set and the columns span owned +
    halo. Matvec output is ``DTensor[Shard(0)]`` (each rank owns the
    owned-row slice of y); no all-reduce required.

    The ``partition`` field carries the irregular per-rank vertex map
    (``owned_nodes`` / ``halo_nodes`` / ``neighbor_partitions``) that
    METIS / Hilbert / RCB produce. ``None`` only when an empty marker
    is being passed around as a type tag.

    Equality and hashing depend only on ``partition.partition_id`` (or
    on the absence of a partition), not on the full partition contents.

    This is the **default** placement -- every Krylov solver, eigsh,
    and persistence path uses it.
    """
    partition: Optional["Partition"] = None

    def _partition_key(self):
        """Identity used by ``__eq__`` and ``__hash__``.

        Returns ``partition.partition_id``, or ``None`` when no partition
        is attached. Uses an explicit ``is None`` test rather than
        truthiness, so a ``Partition`` that defines ``__len__`` or
        ``__bool__`` is never mistaken for a missing one. Using one key for
        both methods keeps "equal objects hash equal" true.
        """
        if self.partition is None:
            return None
        return self.partition.partition_id

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, VertexShard):
            return False
        return self._partition_key() == other._partition_key()

    def __hash__(self) -> int:
        return hash(("VertexShard", self._partition_key()))


@dataclass(frozen=True)
class VertexShardReplicated:
    """Same vertex partition as :class:`VertexShard` but with
    **col-storage** local layout.

    Each rank holds ``A[local_to_global, owned_vertices]`` -- partial
    matvec products end up Replicated after ``all_reduce(SUM)``.
    Specialised path; almost no production code needs this -- only
    transpose-heavy algorithms (normal equations, certain autograd
    paths) benefit. Not yet implemented end-to-end.

    Equality and hashing depend only on ``partition.partition_id`` (or
    on the absence of a partition).
    """
    partition: Optional["Partition"] = None

    def _partition_key(self):
        """Identity used by ``__eq__`` and ``__hash__``.

        Returns ``partition.partition_id``, or ``None`` when no partition
        is attached. Uses ``is None`` rather than truthiness, and one key
        for both methods so that equal objects always hash equal.
        """
        if self.partition is None:
            return None
        return self.partition.partition_id

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, VertexShardReplicated):
            return False
        return self._partition_key() == other._partition_key()

    def __hash__(self) -> int:
        return hash(("VertexShardReplicated", self._partition_key()))


# Deprecated alias for back-compat. Callers should switch to
# :class:`VertexShard` / :class:`VertexShardReplicated` directly.
def SparseShard(axis: int = 0, partition: Optional["Partition"] = None):
    """Deprecated. Use :class:`VertexShard` or :class:`VertexShardReplicated`.

    Always emits a :class:`DeprecationWarning`. Then ``axis=0`` maps to
    :class:`VertexShard` and ``axis=1`` to :class:`VertexShardReplicated`,
    both carrying ``partition``. Any other ``axis`` raises
    :class:`ValueError` (after the warning has been issued).
    """
    warnings.warn(
        "SparseShard(axis=...) is deprecated; use VertexShard() (axis=0) or "
        "VertexShardReplicated() (axis=1) directly. The `axis` parameter was "
        "misleading -- partition is over the vertex set, axis only selected "
        "the local data layout.",
        DeprecationWarning,
        stacklevel=2,
    )
    if axis not in (0, 1):
        raise ValueError(f"axis must be 0 or 1, got {axis}")
    placement_cls = VertexShard if axis == 0 else VertexShardReplicated
    return placement_cls(partition=partition)


@dataclass(frozen=True)
class BatchShard:
    """DSparseTensor placement: shard a **batch** axis (not a sparse axis).

    For a SparseTensor of shape ``(*batch, M, N, *block)`` with
    ``BatchShard(axis=k)``, rank ``r`` holds the contiguous slice
    ``batch[k][start:end]`` of the k-th batch axis. Slices are
    ``chunk``-sized and consecutive, so the trailing ranks may get a
    shorter or even empty slice. The sparse pattern (row and col
    indices) is **replicated** on every rank; only ``values`` is sharded.

    Matvec is embarrassingly parallel: each rank computes its own
    batch slice with zero inter-rank communication. Cross-batch
    reductions (``sum`` / ``norm`` / ``mean`` over the sharded axis)
    use a single ``all_reduce(SUM)``.

    Equality and hashing consider only ``axis``, ``start`` and ``end``.
    """
    axis: int = 0
    # Nominal per-rank slice length; this rank's actual extent is end - start.
    chunk: int = 0
    # First batch index this rank owns.
    start: int = 0
    # One past the last batch index this rank owns.
    end: int = 0
    # Full extent of the sharded batch axis.
    global_size: int = 0

    def _identity(self) -> Tuple[int, int, int]:
        # ``chunk`` and ``global_size`` are intentionally not part of identity.
        return (self.axis, self.start, self.end)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, BatchShard):
            return False
        return self._identity() == other._identity()

    def __hash__(self) -> int:
        return hash((type(self).__name__, *self._identity()))


# Every placement a DSparseTensor can carry (used for annotations).
SparsePlacement = Union[
    Replicated,
    VertexShard,
    VertexShardReplicated,
    BatchShard,
]
# Tuple form for ``isinstance(p, _VERTEX_SHARDS)`` checks inside dispatch:
# both vertex-partitioned placements carry a ``partition`` field.
_VERTEX_SHARDS = (
    VertexShard,
    VertexShardReplicated,
)


@dataclass(frozen=True)
class DSparseSpec:
    """The sparse analog of :class:`torch.distributed.tensor.DTensorSpec`.

    Bundles the placement, the device mesh, and the global shape so we
    can dispatch operations purely off the spec and treat
    :class:`DSparseTensor` like any other distributed tensor.
    """
    placement: SparsePlacement
    # torch.distributed.DeviceMesh; constructors also accept ``None``
    # (treated as a single partition).
    mesh: Any
    # Full global shape: ``(M, N)`` for vertex-sharded matrices, but the
    # whole ``(*batch, M, N, *block)`` shape for BatchShard tensors -- hence
    # a variable-length tuple rather than ``Tuple[int, int]``.
    global_shape: Tuple[int, ...]


class DSparseTensor:
    """Distributed sparse tensor: a per-rank :class:`SparseTensor` shard
    plus a :class:`DSparseSpec` (placement + device mesh + global shape).

    Direct instantiation is not supported; build one with a classmethod
    constructor:

    * :meth:`partition` -- global :class:`SparseTensor` + mesh; rows are
      split across ranks (:class:`VertexShard` placement).
    * :meth:`from_global_distributed` -- global COO arrays + explicit
      rank / world size; rank 0 computes the partition ids and
      broadcasts them.
    * :meth:`from_sparse_local` -- an already-extracted per-rank
      ``(SparseTensor, Partition)`` pair.
    * :meth:`partition_batch` -- shard a batch axis of a batched
      :class:`SparseTensor` (:class:`BatchShard` placement).
    * :meth:`load` -- reload shards written by :meth:`save`.

    Vectors crossing the API are ``DTensor[Shard(0)]`` holding each
    rank's owned rows. :meth:`scatter` builds one from a global vector,
    and :meth:`full_tensor` gathers the whole matrix onto every rank.

    Example
    -------
    ::

        from torch_sla import SparseTensor, DSparseTensor, solve, SolverConfig
        from torch.distributed.device_mesh import init_device_mesh

        A    = SparseTensor(val, row, col, shape)
        mesh = init_device_mesh("cpu", (world_size,))
        D    = DSparseTensor.partition(A, mesh, partition_method="metis")
        b_dt = D.scatter(b_global)

        with SolverConfig(method="cg", atol=1e-10, rtol=1e-10, maxiter=2000):
            x_dt = solve(D, b_dt)

        A_full = D.full_tensor()   # collective: every rank gets the matrix
    """

    def __init__(self) -> None:
        """Direct instantiation isn't supported -- always raises
        :class:`TypeError`. Use one of the classmethod constructors:

        * :meth:`partition` -- global :class:`SparseTensor` + mesh →
          row-sharded :class:`DSparseTensor`.
        * :meth:`from_global_distributed` -- global COO + rank/world →
          row-sharded :class:`DSparseTensor` (broadcasts partition ids
          from rank 0 for determinism).
        * :meth:`from_sparse_local` -- per-rank ``(SparseTensor,
          Partition)`` → :class:`DSparseTensor`.
        * :meth:`partition_batch` -- batched :class:`SparseTensor` → a
          batch-sharded :class:`DSparseTensor`.

        Each populates ``_local_tensor`` (the per-rank SparseTensor
        backing) and ``_spec`` (the placement + mesh + global shape).
        """
        raise TypeError(
            "DSparseTensor() does not support direct instantiation. Use "
            "DSparseTensor.partition(A, mesh) / "
            "DSparseTensor.from_global_distributed(...) / "
            "DSparseTensor.from_sparse_local(...)."
        )

    # ====================================================================== #
    # DTensor-mirror API (from_sparse_local / scatter / full_tensor / ...).
    #
    # These methods give DSparseTensor the same shape of API as
    # ``torch.distributed.tensor.DTensor``: every call resolves through
    # a private :class:`DSparseSpec` that bundles the placement, the
    # device mesh, and the global shape. Vectors crossing the API stay
    # as ``DTensor[Shard(0)]`` so they interoperate with other
    # DTensor-based code.
    # ====================================================================== #
[docs] @classmethod def from_sparse_local( cls, local_tensor: "SparseTensor", mesh: Any, partition: "Partition", *, axis: int = 0, global_shape: Optional[Tuple[int, int]] = None, ) -> "DSparseTensor": """Wrap a per-rank :class:`SparseTensor` chunk (already in local coords) plus its :class:`Partition` as a DSparseTensor. Use together with :meth:`SparseTensor.extract_partition`: .. code-block:: python partition = compute_partition(...) local_tensor = A_global.extract_partition(partition) D = DSparseTensor.from_sparse_local( local_tensor, mesh, partition, global_shape=A_global.shape, ) y_dt = D @ x_dt # halo exchange + local SpMV The partition is stamped onto ``_spec.placement.partition`` so the placement is the single source of truth for the irregular shard map. Parameters ---------- local_tensor : SparseTensor This rank's local subdomain (size ``(num_local, num_local)``, COO in local coordinates). Usually built by :meth:`SparseTensor.extract_partition`. mesh : DeviceMesh The PyTorch device mesh. partition : Partition Irregular partition map for this rank (``owned_nodes`` / ``halo_nodes`` / ``neighbor_partitions`` etc). axis : int Sparse axis being sharded (default 0 = rows). global_shape : Tuple[int, int], optional Global matrix shape. If omitted, inferred from ``partition.local_to_global.numel()`` -- only valid for square matrices. 
""" if global_shape is None: n = int(partition.local_to_global.numel() + 0) # placeholder; caller should pass it explicitly global_shape = (n, n) self = cls.__new__(cls) self._values = None self._row_indices = None self._col_indices = None self._shape = global_shape self._num_partitions = mesh.size() if mesh is not None else 1 self._coords = None self._partition_method = None self._verbose = False self._device = local_tensor.values.device self._local_tensor = local_tensor self._halo_send_buffers = {} self._halo_recv_buffers = {} if axis == 0: placement = VertexShard(partition=partition) elif axis == 1: placement = VertexShardReplicated(partition=partition) else: raise ValueError(f"axis must be 0 or 1, got {axis}") self._spec = DSparseSpec(placement=placement, mesh=mesh, global_shape=global_shape) return self
[docs] @classmethod def partition_batch( cls, A: "SparseTensor", mesh: Any, *, axis: int = 0, ) -> "DSparseTensor": """Batch-shard a batched :class:`SparseTensor` across ``mesh``. Every rank gets the same row/col indices; only the values tensor is sliced along ``A.batch_shape[axis]``. No halo exchange, no cross-rank comm in matvec. Requires ``A.is_batched`` and ``axis < len(A.batch_shape)``. """ if not A.is_batched: raise ValueError("partition_batch requires a batched SparseTensor") if axis < 0 or axis >= len(A.batch_shape): raise ValueError( f"axis {axis} out of range for batch_shape {A.batch_shape}") try: import torch.distributed as dist rank = dist.get_rank() if dist.is_initialized() else 0 world = dist.get_world_size() if dist.is_initialized() else 1 except (RuntimeError, ImportError): rank, world = 0, 1 B = int(A.batch_shape[axis]) chunk = (B + world - 1) // world start = min(rank * chunk, B) end = min(start + chunk, B) my_size = end - start # Slice values along the sharded batch axis. SparseTensor's # ``values`` has shape ``[*batch, nnz, *block]`` so the batch # axis position matches ``axis``. new_values = A.values.narrow(axis, start, my_size) # Sub-tensor's batch_shape replaces the sharded extent. new_shape = list(A.shape) new_shape[axis] = my_size from ..sparse_tensor import SparseTensor local_st = SparseTensor(new_values, A.row_indices, A.col_indices, shape=tuple(new_shape), sparse_dim=A.sparse_dim) placement = BatchShard(axis=axis, chunk=chunk, start=start, end=end, global_size=B) self = cls.__new__(cls) self._values = None self._row_indices = None self._col_indices = None self._shape = tuple(A.shape) self._num_partitions = world self._coords = None self._partition_method = None self._verbose = False self._device = local_st.values.device self._local_tensor = local_st self._halo_send_buffers = {} self._halo_recv_buffers = {} self._spec = DSparseSpec(placement=placement, mesh=mesh, global_shape=tuple(A.shape)) return self
[docs] @classmethod def partition( cls, A: "SparseTensor", mesh: Any, *, partition_method: str = "simple", coords: Optional[torch.Tensor] = None, verbose: bool = False, ) -> "DSparseTensor": """One-shot constructor: take a global :class:`SparseTensor` + :class:`DeviceMesh`, partition rows across the mesh, return a ready-to-use distributed tensor with :class:`RowPartitioned` placement. Equivalent to:: local = A.partition_for_rank(rank, world_size, partition_method=partition_method, coords=coords) D = DSparseTensor.from_local(local, mesh, placement=RowPartitioned()) but in one line. This is the recommended way to build a distributed sparse tensor from a global :class:`SparseTensor` for both unit tests and small-to-medium production runs (where every rank can afford to hold the global ``A`` briefly). For memory-tight scenarios where only rank 0 should ever materialise the global matrix, use :meth:`from_global_distributed` (which broadcasts only the partition IDs from rank 0) and chain :meth:`from_local` manually. Parameters ---------- A : SparseTensor Global sparse matrix; every rank should hold an identical copy at the time of the call. mesh : DeviceMesh Target device mesh. ``mesh.size()`` becomes the world size and ``dist.get_rank()`` picks this rank's chunk. partition_method : str Partitioning algorithm passed through to :meth:`SparseTensor.partition_for_rank`: ``"simple"`` / ``"metis"`` / ``"rcb"`` / ``"slicing"``. coords : torch.Tensor, optional Node coordinates for geometric partitioning (RCB/slicing). verbose : bool Print partition info on each rank. """ if DIST_AVAILABLE and dist.is_initialized(): rank = dist.get_rank() else: rank = 0 world_size = mesh.size() if mesh is not None else 1 # Compute partition ids, build the Partition struct, extract # this rank's local SparseTensor, wrap. # # Determinism: when a process group is active and world_size>1, # let rank 0 compute the partition ids and broadcast the result # to every other rank. 
Some partitioners (notably parallel # METIS variants) seed their RNG from a thread- or # process-local source and can yield different labellings on # different ranks. If that drift happens silently, the owned / # halo bookkeeping disagrees across ranks, halo exchanges land # in the wrong slot, and distributed CG converges to a wrong # answer with no error raised. Mirrors the pattern in # :meth:`from_global_distributed`. if (DIST_AVAILABLE and dist.is_initialized() and world_size > 1): # Pick the broadcast device by backend: NCCL refuses CPU # tensors ("No backend type associated with device type # cpu"), Gloo prefers CPU. Match the active backend and # move the result back to CPU afterwards so downstream # code (build_partition / extract_partition) sees a plain # CPU LongTensor regardless of backend. backend = dist.get_backend() bcast_device = ( torch.device("cuda", torch.cuda.current_device()) if backend == "nccl" else torch.device("cpu") ) n_rows = int(A.shape[0]) if rank == 0: partition_ids = resolve_partition_ids( A.row_indices, A.col_indices, n_rows, world_size, method=partition_method, coords=coords, ).to(bcast_device) else: partition_ids = torch.zeros(n_rows, dtype=torch.int64, device=bcast_device) dist.broadcast(partition_ids, src=0) partition_ids = partition_ids.cpu() else: partition_ids = resolve_partition_ids( A.row_indices, A.col_indices, int(A.shape[0]), world_size, method=partition_method, coords=coords, ) partition = build_partition( A.row_indices, A.col_indices, int(A.shape[0]), partition_ids, rank, ) local_st = A.extract_partition(partition) return cls.from_sparse_local( local_st, mesh, partition, global_shape=tuple(A.shape), )
@property def spec(self) -> Optional[DSparseSpec]: """The :class:`DSparseSpec` for this tensor (placement + mesh + global shape), or ``None`` if this instance was built via the legacy single-process simulator constructor.""" return self._spec
[docs] def scatter(self, global_vec: torch.Tensor) -> "DTensor": """Convenience: extract this rank's owned slice from a global vector and wrap as a ``DTensor[Shard(0)]``. Common usage:: b_dt = D.scatter(b_global) # build distributed RHS x_dt = solve(D, b_dt) # distributed solve r_dt = b_dt - D @ x_dt # distributed residual ``global_vec`` is a 1-D ``torch.Tensor`` of size ``global_shape[0]``. Every rank should hold the same copy (typical in tests; in production the caller loads on rank 0 and broadcasts). """ partition = self._partition_for_dispatch() if partition is None: raise RuntimeError( "scatter() requires a partition map -- build this " "DSparseTensor via .partition(...) or .from_local(...)") owned = partition.owned_nodes.to(device=global_vec.device, dtype=torch.int64) local_slice = global_vec[owned].contiguous() _DTensor = DTensor # use module-level import with fallback return _DTensor.from_local(local_slice, self._spec.mesh, [Shard(0)])
def _partition_for_dispatch(self) -> Optional["Partition"]: """Return the active :class:`Partition` from the spec, or ``None`` if no spec is set.""" if self._spec is not None and isinstance( self._spec.placement, _VERTEX_SHARDS) \ and self._spec.placement.partition is not None: return self._spec.placement.partition return None
[docs] def full_tensor(self) -> "SparseTensor": """Materialise the full global tensor on every rank. Mirrors :meth:`DTensor.full_tensor`. For :class:`SparseShard(axis=0)` we drop halo rows, translate indices to global, and allgather the COO triples. For :class:`BatchShard` we allgather the per-rank values slices along the sharded batch axis. """ from ..sparse_tensor import SparseTensor if self._spec is None: raise RuntimeError("DSparseTensor.full_tensor() requires a spec") if isinstance(self._spec.placement, BatchShard): placement = self._spec.placement local_vals = self._local_tensor.values.contiguous() if not (DIST_AVAILABLE and dist.is_initialized()): full_vals = local_vals else: world = dist.get_world_size() sizes = torch.tensor([local_vals.shape[placement.axis]], dtype=torch.long, device=local_vals.device) all_sizes = [torch.zeros_like(sizes) for _ in range(world)] dist.all_gather(all_sizes, sizes) sizes_l = [int(s.item()) for s in all_sizes] max_size = max(sizes_l) pad_n = max_size - local_vals.shape[placement.axis] if pad_n > 0: pad_shape = list(local_vals.shape) pad_shape[placement.axis] = pad_n pad = torch.zeros(pad_shape, dtype=local_vals.dtype, device=local_vals.device) padded = torch.cat([local_vals, pad], dim=placement.axis) else: padded = local_vals gathered = [torch.zeros_like(padded) for _ in range(world)] dist.all_gather(gathered, padded) slices = [g.narrow(placement.axis, 0, sz) for g, sz in zip(gathered, sizes_l)] full_vals = torch.cat(slices, dim=placement.axis) return SparseTensor( full_vals, self._local_tensor.row_indices, self._local_tensor.col_indices, shape=self._spec.global_shape, sparse_dim=self._local_tensor.sparse_dim, ) partition = self._partition_for_dispatch() if partition is None: raise RuntimeError( "DSparseTensor.full_tensor() requires a SparseShard " "placement with a Partition.") st = self._local_tensor if st is None: raise RuntimeError( "DSparseTensor.full_tensor() requires a SparseTensor " "backing.") # Drop halo rows -- only 
owned rows participate in the global # matrix. Local row indices < num_owned are the owned ones. device = st.values.device num_owned = int(partition.owned_nodes.numel()) owned_mask = st.row_indices < num_owned local_rows = st.row_indices[owned_mask] local_cols = st.col_indices[owned_mask] local_vals = st.values[owned_mask] # Translate local row / col → global indices. l2g = partition.local_to_global.to(device=device, dtype=torch.int64) global_rows = l2g[local_rows] global_cols = l2g[local_cols] if not (DIST_AVAILABLE and dist.is_initialized()): return SparseTensor(local_vals, global_rows, global_cols, tuple(self._spec.global_shape)) # All-gather the per-rank triples across the mesh. world_size = dist.get_world_size() nnz_t = torch.tensor([int(global_rows.numel())], device=device, dtype=torch.int64) all_nnz = [torch.zeros(1, device=device, dtype=torch.int64) for _ in range(world_size)] dist.all_gather(all_nnz, nnz_t) sizes = [int(t.item()) for t in all_nnz] max_nnz = max(sizes) def _padded(t, dtype): out = torch.zeros(max_nnz, device=device, dtype=dtype) out[:t.numel()] = t.to(dtype=dtype) return out val_pad = _padded(local_vals, local_vals.dtype) row_pad = _padded(global_rows, torch.int64) col_pad = _padded(global_cols, torch.int64) all_vals = [torch.zeros_like(val_pad) for _ in range(world_size)] all_rows = [torch.zeros_like(row_pad) for _ in range(world_size)] all_cols = [torch.zeros_like(col_pad) for _ in range(world_size)] dist.all_gather(all_vals, val_pad) dist.all_gather(all_rows, row_pad) dist.all_gather(all_cols, col_pad) out_vals = torch.cat([all_vals[r][:sizes[r]] for r in range(world_size)]) out_rows = torch.cat([all_rows[r][:sizes[r]] for r in range(world_size)]) out_cols = torch.cat([all_cols[r][:sizes[r]] for r in range(world_size)]) return SparseTensor(out_vals, out_rows, out_cols, tuple(self._spec.global_shape))
[docs] @classmethod def from_global_distributed( cls, values: torch.Tensor, row_indices: torch.Tensor, col_indices: torch.Tensor, shape: Tuple[int, int], rank: int, world_size: int, mesh: Any = None, coords: Optional[torch.Tensor] = None, partition_method: str = 'auto', device: Optional[Union[str, torch.device]] = None, verbose: bool = True ) -> "DSparseTensor": """ Create local partition in a distributed-safe manner. This method ensures that all ranks compute the same partition assignment by having rank 0 compute the partition IDs and broadcasting to all ranks. Parameters ---------- values : torch.Tensor Global non-zero values [nnz] row_indices : torch.Tensor Global row indices [nnz] col_indices : torch.Tensor Global column indices [nnz] shape : Tuple[int, int] Global matrix shape (M, N) rank : int Current process rank world_size : int Total number of processes coords : torch.Tensor, optional Node coordinates for geometric partitioning [num_nodes, dim] partition_method : str Partitioning method: 'metis', 'rcb', 'slicing', 'simple' device : str or torch.device, optional Target device verbose : bool Whether to print partition info Returns ------- DSparseTensor This rank's row-sharded distributed tensor. Example ------- >>> import torch.distributed as dist >>> >>> # In each process: >>> rank = dist.get_rank() >>> world_size = dist.get_world_size() >>> >>> local_matrix = DSparseTensor.from_global_distributed( ... val, row, col, shape, ... rank=rank, world_size=world_size ... ) """ import torch.distributed as dist if device is None: device = values.device if isinstance(device, str): device = torch.device(device) # Compute partition IDs on rank 0 and broadcast for determinism. 
if rank == 0: partition_ids = resolve_partition_ids( row_indices, col_indices, int(shape[0]), world_size, method=partition_method, coords=coords, ).to(device) else: partition_ids = torch.zeros(shape[0], dtype=torch.int64, device=device) dist.broadcast(partition_ids, src=0) # Build Partition struct + extract local SparseTensor on this rank. partition = build_partition( row_indices, col_indices, int(shape[0]), partition_ids.cpu(), rank, ) from ..sparse_tensor import SparseTensor A = SparseTensor(values, row_indices, col_indices, shape) local_st = A.extract_partition(partition) # If no mesh was passed, build a 1-D mesh from the process # group so the result is still a real ``DSparseTensor[Shard(0)]``. if mesh is None: try: from torch.distributed.device_mesh import init_device_mesh except ImportError: from torch.distributed._tensor.device_mesh import init_device_mesh mesh = init_device_mesh(str(device.type), (world_size,)) return cls.from_sparse_local( local_st, mesh, partition, global_shape=tuple(shape), )
[docs] def save( self, directory: Any, rank: Optional[int] = None, verbose: bool = False, ) -> None: """Persist this rank's shard to ``directory``. Convenience for :func:`torch_sla.io.save_dsparse(self, directory)`.""" from ..io import save_dsparse save_dsparse(self, directory, rank=rank, verbose=verbose)
[docs] @classmethod def load( cls, directory: Any, mesh: Any = None, rank: Optional[int] = None, target_world_size: Optional[int] = None, device: Union[str, torch.device] = "cpu", ) -> "DSparseTensor": """Reconstruct a :class:`DSparseTensor` from a directory previously written by :meth:`save` / :func:`save_dsparse` / :func:`save_sparse_sharded`. Pass ``target_world_size=1`` (or call from a single process with no live ``torch.distributed`` group) to gather all shards into one trivial ``mesh=None`` DSparseTensor -- useful for offline inspection of a sharded archive. If ``stored_num_partitions != target_world_size`` and the target is not 1, raises :class:`NotImplementedError` (true cross-world-size repartition is deferred to a future ``redistribute()``). """ from ..io import load_dsparse return load_dsparse(directory, mesh=mesh, rank=rank, target_world_size=target_world_size, device=device)
# ========================================================================= # Properties # ========================================================================= @property def shape(self) -> Tuple[int, int]: """Global matrix shape.""" return self._shape @property def num_partitions(self) -> int: """Number of partitions.""" return self._num_partitions @property def device(self) -> torch.device: """Device of the matrix data.""" return self._device @property def dtype(self) -> torch.dtype: """Data type of matrix values.""" return self._local_tensor.values.dtype @property def nnz(self) -> int: """Local nnz on this rank. Use :meth:`global_nnz` for the sum.""" return int(self._local_tensor.values.numel())
[docs] def global_nnz(self) -> int: """Sum of ``nnz`` across all ranks. Collective; cached.""" cached = getattr(self, "_global_nnz_cache", None) if cached is not None: return cached local = torch.tensor([self.nnz], dtype=torch.long, device=self._device) if DIST_AVAILABLE and dist.is_initialized(): dist.all_reduce(local, op=dist.ReduceOp.SUM) total = int(local.item()) self._global_nnz_cache = total return total
@property
def ndim(self) -> int:
    """Always 2 -- the tensor is a matrix."""
    return 2


@property
def sparse_shape(self) -> Tuple[int, int]:
    """Global ``(M, N)``; same value as :attr:`shape`."""
    return self._shape


@property
def sparse_dim(self) -> Tuple[int, int]:
    """Indices of the sparse dimensions: both axes."""
    return (0, 1)


@property
def batch_shape(self) -> Tuple[int, ...]:
    """Always empty -- no batch dimensions are reported."""
    return ()


@property
def block_shape(self) -> Tuple[int, ...]:
    """Always empty -- no block dimensions are reported."""
    return ()


@property
def batch_size(self) -> int:
    """Always 1."""
    return 1


@property
def is_batched(self) -> bool:
    """Always ``False``."""
    return False


@property
def is_block(self) -> bool:
    """Always ``False``."""
    return False


@property
def is_cuda(self) -> bool:
    """``True`` when the data device type is ``"cuda"``."""
    device_kind = self._device.type
    return device_kind == "cuda"


@property
def is_square(self) -> bool:
    """``True`` when the global row and column counts match."""
    n_rows, n_cols = self._shape
    return n_rows == n_cols


@property
def values(self) -> torch.Tensor:
    """This rank's stored values (local shard)."""
    return self._local_tensor.values


@property
def row_indices(self) -> torch.Tensor:
    """This rank's local row indices."""
    return self._local_tensor.row_indices


@property
def col_indices(self) -> torch.Tensor:
    """This rank's local column indices."""
    return self._local_tensor.col_indices
def to(
    self,
    device: Optional[Union[str, torch.device]] = None,
    dtype: Optional[torch.dtype] = None,
) -> "DSparseTensor":
    """Return a copy on ``device`` and/or with values cast to ``dtype``.

    With both arguments ``None`` this returns ``self`` itself.  The local
    shard goes through ``SparseTensor.to``.  For VertexShard-family
    placements the partition map is also moved when a device is given.
    BatchShard and unknown placements are reused as-is.  The copy keeps
    the mesh and global shape and starts with empty halo buffers.
    """
    if device is None and dtype is None:
        return self
    if isinstance(device, str):
        device = torch.device(device)

    moved_local = self._local_tensor.to(device=device, dtype=dtype)
    old_placement = self._spec.placement

    # Only the VertexShard family carries ``.partition``; BatchShard doesn't.
    if isinstance(old_placement, _VERTEX_SHARDS):
        moved_partition = old_placement.partition
        if device is not None and moved_partition is not None:
            moved_partition = moved_partition.to(device)
    else:
        moved_partition = None

    clone = type(self).__new__(type(self))
    clone._values = None
    clone._row_indices = None
    clone._col_indices = None
    clone._shape = self._shape
    clone._num_partitions = self._num_partitions
    clone._coords = self._coords
    clone._partition_method = self._partition_method
    clone._verbose = self._verbose
    clone._device = moved_local.values.device
    clone._local_tensor = moved_local
    clone._halo_send_buffers = {}
    clone._halo_recv_buffers = {}

    if isinstance(old_placement, BatchShard):
        new_placement = old_placement
    elif isinstance(old_placement, VertexShard):
        new_placement = VertexShard(partition=moved_partition)
    elif isinstance(old_placement, VertexShardReplicated):
        new_placement = VertexShardReplicated(partition=moved_partition)
    else:
        new_placement = old_placement

    clone._spec = DSparseSpec(
        placement=new_placement,
        mesh=self._spec.mesh,
        global_shape=self._spec.global_shape,
    )
    return clone
def cuda(self, device: Optional[int] = None) -> "DSparseTensor":
    """Move to CUDA: plain ``"cuda"`` when ``device`` is None, else ``"cuda:<device>"``."""
    if device is None:
        return self.to("cuda")
    return self.to(f"cuda:{device}")
def cpu(self) -> "DSparseTensor":
    """Return a copy moved to the CPU (see :meth:`to`)."""
    target = "cpu"
    return self.to(target)
def float(self) -> "DSparseTensor":
    """Return a copy with values cast to ``torch.float32``."""
    precision = torch.float32
    return self.to(dtype=precision)
def double(self) -> "DSparseTensor":
    """Return a copy with values cast to ``torch.float64``."""
    precision = torch.float64
    return self.to(dtype=precision)
def half(self) -> "DSparseTensor":
    """Return a copy with values cast to ``torch.float16``."""
    precision = torch.float16
    return self.to(dtype=precision)
# Reductions cover stored non-zero values only (matches SparseTensor).
# Each rank's _local_tensor holds disjoint owned rows, so summing
# local results gives the global value without double-counting.
def _all_reduce_scalar(self, value: torch.Tensor, op: "dist.ReduceOp" = None) -> torch.Tensor:
    """All-reduce ``value`` in place across ranks and return it.

    ``op`` defaults to SUM.  When ``torch.distributed`` is unavailable or
    no process group is initialised, ``value`` is returned untouched.
    """
    if not (DIST_AVAILABLE and dist.is_initialized()):
        return value
    dist.all_reduce(value, op=dist.ReduceOp.SUM if op is None else op)
    return value
def sum(self, axis: Optional[Union[int, Tuple[int, ...]]] = None,
        keepdim: bool = False) -> torch.Tensor:
    """Sum stored values. ``axis`` in ``{None, 0, 1}``.

    * ``axis=None``      -> scalar grand total over all ranks.
    * ``axis=0`` / ``-2`` -> length-``N`` column sums.
    * ``axis=1`` / ``-1`` -> length-``M`` row sums.

    ``keepdim`` re-inserts the reduced axis for the per-axis forms.
    Collective whenever a process group is initialised.

    Raises:
        ValueError: for any other ``axis``.
    """
    if isinstance(self._spec.placement, BatchShard):
        # NOTE(review): BatchShard ignores ``axis``/``keepdim`` and always
        # returns the grand total over every batch and rank.
        local_sum = self._local_tensor.values.sum()
        return self._all_reduce_scalar(local_sum)

    local = self._local_tensor
    if axis is None:
        # ``.sum()`` already allocates a fresh tensor, so the in-place
        # all-reduce cannot clobber ``local.values``; autograd still flows
        # back to ``values`` when it requires grad.
        return self._all_reduce_scalar(local.values.sum())

    M, N = self._shape
    partition = self._spec.placement.partition
    if axis in (0, -2):
        local_idx, length, keep_axis = local.col_indices, N, 0
    elif axis in (1, -1):
        local_idx, length, keep_axis = local.row_indices, M, 1
    else:
        raise ValueError(f"axis {axis} out of range (None, 0, 1)")

    # scatter_add_ requires int64 indices; the stored index tensors / the
    # local->global map are not guaranteed to be int64.
    idx = partition.local_to_global[local_idx].long()
    out = torch.zeros(length, dtype=local.values.dtype, device=local.values.device)
    out.scatter_add_(0, idx, local.values)
    if DIST_AVAILABLE and dist.is_initialized():
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out.unsqueeze(keep_axis) if keepdim else out
def mean(self, axis: Optional[Union[int, Tuple[int, ...]]] = None) -> torch.Tensor:
    """Mean over stored values (implicit zeros excluded).

    * ``axis=None``: global sum divided by :meth:`global_nnz` (at least 1).
    * ``axis`` 0/1: each column/row sum divided by the number of stored
      entries in that column/row (empty ones divide by 1, giving 0).
    * BatchShard placements: 0-d mean over every batch and rank,
      regardless of ``axis``.

    Collective whenever a process group is initialised.

    Raises:
        ValueError: for an ``axis`` outside ``{None, 0, 1, -1, -2}``
            (raised by :meth:`sum`).
    """
    if isinstance(self._spec.placement, BatchShard):
        values = self._local_tensor.values
        total = self._all_reduce_scalar(values.sum())
        # 0-d count so the result is 0-d like the ``axis=None`` path.
        count = torch.tensor(values.numel(), dtype=torch.long, device=values.device)
        if DIST_AVAILABLE and dist.is_initialized():
            dist.all_reduce(count, op=dist.ReduceOp.SUM)
        return total / count.clamp_(min=1).to(total.dtype)

    total = self.sum(axis=axis)
    if axis is None:
        return total / max(1, self.global_nnz())

    local = self._local_tensor
    partition = self._spec.placement.partition
    M, N = self._shape
    if axis in (0, -2):
        local_idx, length = local.col_indices, N
    else:
        local_idx, length = local.row_indices, M

    # scatter_add_ requires int64 indices (and ``src`` must match the
    # int64 ``counts`` dtype), so normalise the index dtype first.
    idx = partition.local_to_global[local_idx].long()
    counts = torch.zeros(length, dtype=torch.long, device=local.values.device)
    counts.scatter_add_(0, idx, torch.ones_like(idx))
    if DIST_AVAILABLE and dist.is_initialized():
        dist.all_reduce(counts, op=dist.ReduceOp.SUM)
    return total / counts.clamp_(min=1).to(total.dtype)
def prod(self) -> torch.Tensor:
    """Product of every stored value across all ranks.

    Each rank's partial product is all-gathered and multiplied locally,
    rather than all-reduced with a PROD op (noted as missing on gloo).
    Without a process group the local product is returned.
    """
    partial = self._local_tensor.values.prod()
    if DIST_AVAILABLE and dist.is_initialized():
        world = dist.get_world_size()
        pieces = [torch.zeros_like(partial) for _ in range(world)]
        dist.all_gather(pieces, partial)
        return torch.stack(pieces).prod()
    return partial
def max(self) -> torch.Tensor:
    """Largest stored value across all ranks (collective).

    A rank with no stored values contributes the identity of MAX
    (``-inf`` for floating dtypes, the dtype minimum otherwise) instead
    of raising.  Raising on one rank would leave the other ranks blocked
    inside the all-reduce.
    """
    values = self._local_tensor.values
    if values.numel() > 0:
        local_max = values.max()
    elif values.is_floating_point():
        local_max = torch.tensor(-torch.inf, dtype=values.dtype, device=values.device)
    else:
        local_max = torch.tensor(torch.iinfo(values.dtype).min,
                                 dtype=values.dtype, device=values.device)
    # Only touch ``dist`` when torch.distributed imported successfully.
    op = dist.ReduceOp.MAX if DIST_AVAILABLE else None
    return self._all_reduce_scalar(local_max, op)
def min(self) -> torch.Tensor:
    """Smallest stored value across all ranks (collective).

    A rank with no stored values contributes the identity of MIN
    (``+inf`` for floating dtypes, the dtype maximum otherwise) instead
    of raising.  Raising on one rank would leave the other ranks blocked
    inside the all-reduce.
    """
    values = self._local_tensor.values
    if values.numel() > 0:
        local_min = values.min()
    elif values.is_floating_point():
        local_min = torch.tensor(torch.inf, dtype=values.dtype, device=values.device)
    else:
        local_min = torch.tensor(torch.iinfo(values.dtype).max,
                                 dtype=values.dtype, device=values.device)
    # Only touch ``dist`` when torch.distributed imported successfully.
    op = dist.ReduceOp.MIN if DIST_AVAILABLE else None
    return self._all_reduce_scalar(local_min, op)
def norm(self, ord: Any = "fro") -> torch.Tensor:
    """Global matrix norm for ``ord`` in ``'fro'`` / ``1`` / ``inf``.

    ``1`` is the largest column abs-sum, ``inf`` the largest row
    abs-sum.  ``2`` requires ``full_tensor().norm(2)``.  BatchShard
    placements only support ``'fro'``.

    Raises:
        NotImplementedError: ``ord=2``, or ``1``/``inf`` under BatchShard.
        ValueError: any other ``ord``.
    """
    if ord == "fro":
        v = self._local_tensor.values
        if v.is_complex():
            local_sq = (v.real ** 2 + v.imag ** 2).sum()
        elif v.dtype in (torch.float16, torch.bfloat16):
            # Accumulate half-precision squares in float32.
            local_sq = (v.float() ** 2).sum()
        else:
            local_sq = (v ** 2).sum()
        # _all_reduce_scalar defaults to SUM and skips the reduction when
        # torch.distributed is unavailable, so don't reference ``dist`` here.
        return self._all_reduce_scalar(local_sq).sqrt()
    if isinstance(self._spec.placement, BatchShard):
        raise NotImplementedError(
            "BatchShard norm only supports 'fro'; for 1/inf use full_tensor()")
    if ord == 1:
        return self._abs_axis_sum(axis=0).max()
    if ord == float("inf"):
        return self._abs_axis_sum(axis=1).max()
    if ord == 2:
        raise NotImplementedError("use full_tensor().norm(2)")
    raise ValueError(f"unsupported norm order: {ord!r}")
def _abs_axis_sum(self, axis: int) -> torch.Tensor:
    """Global per-column (``axis=0``) or per-row (``axis=1``) sums of ``|a_ij|``.

    Helper for :meth:`norm` with ``ord`` 1 / inf.  Collective whenever a
    process group is initialised.

    Raises:
        ValueError: if ``axis`` is not 0/-2 or 1/-1.
    """
    local = self._local_tensor
    partition = self._spec.placement.partition
    M, N = self._shape
    abs_v = local.values.abs()
    if axis in (0, -2):
        out = torch.zeros(N, dtype=abs_v.dtype, device=abs_v.device)
        idx = partition.local_to_global[local.col_indices]
    elif axis in (1, -1):
        out = torch.zeros(M, dtype=abs_v.dtype, device=abs_v.device)
        idx = partition.local_to_global[local.row_indices]
    else:
        raise ValueError(f"axis must be 0 or 1, got {axis}")
    # scatter_add_ requires int64 indices; the stored index tensors / the
    # local->global map are not guaranteed to be int64.
    out.scatter_add_(0, idx.long(), abs_v)
    if DIST_AVAILABLE and dist.is_initialized():
        dist.all_reduce(out, op=dist.ReduceOp.SUM)
    return out

# Element-wise math: delegate to per-rank SparseTensor, re-wrap with
# same spec. Same-spec DSparseTensor + DSparseTensor allowed when COO
# patterns match (SparseTensor.__add__ enforces locally).
def _wrap_local(self, local: "SparseTensor") -> "DSparseTensor":
    """Wrap a per-rank ``SparseTensor`` in a new DSparseTensor that shares
    ``self``'s spec and metadata and starts with empty halo buffers."""
    out = type(self).__new__(type(self))
    out._values = None
    out._row_indices = None
    out._col_indices = None
    out._shape = self._shape
    out._num_partitions = self._num_partitions
    out._coords = self._coords
    out._partition_method = self._partition_method
    out._verbose = self._verbose
    out._device = local.values.device
    out._local_tensor = local
    out._halo_send_buffers = {}
    out._halo_recv_buffers = {}
    out._spec = self._spec
    return out

def _coerce_other_local(self, other):
    """Return ``other``'s local shard if it is a DSparseTensor, else ``other``.

    Raises:
        ValueError: if a DSparseTensor operand uses a different DeviceMesh
            object or has a different global shape.
    """
    if isinstance(other, DSparseTensor):
        if other._spec.mesh is not self._spec.mesh:
            raise ValueError("element-wise op: operands must share DeviceMesh")
        if other._spec.global_shape != self._spec.global_shape:
            raise ValueError(
                f"shape mismatch {self._spec.global_shape} vs {other._spec.global_shape}")
        return other._local_tensor
    return other

def __add__(self, other) -> "DSparseTensor":
    return self._wrap_local(self._local_tensor + self._coerce_other_local(other))

def __radd__(self, other) -> "DSparseTensor":
    return self.__add__(other)

def __sub__(self, other) -> "DSparseTensor":
    return self._wrap_local(self._local_tensor - self._coerce_other_local(other))

def __rsub__(self, other) -> "DSparseTensor":
    return self._wrap_local(other - self._local_tensor)

def __mul__(self, other) -> "DSparseTensor":
    return self._wrap_local(self._local_tensor * self._coerce_other_local(other))

def __rmul__(self, other) -> "DSparseTensor":
    return self.__mul__(other)

def __truediv__(self, other) -> "DSparseTensor":
    return self._wrap_local(self._local_tensor / self._coerce_other_local(other))

def __pow__(self, exponent) -> "DSparseTensor":
    return self._wrap_local(self._local_tensor ** exponent)

def __neg__(self) -> "DSparseTensor":
    return self._wrap_local(-self._local_tensor)

def __pos__(self) -> "DSparseTensor":
    return self

def __abs__(self) -> "DSparseTensor":
    return self.abs()
def abs(self) -> "DSparseTensor":
    """Element-wise absolute value of the stored entries (same spec)."""
    magnitudes = self._local_tensor.abs()
    return self._wrap_local(magnitudes)
def sqrt(self) -> "DSparseTensor":
    """Element-wise square root of the stored entries (same spec)."""
    roots = self._local_tensor.sqrt()
    return self._wrap_local(roots)
def square(self) -> "DSparseTensor":
    """Element-wise square of the stored entries (same spec)."""
    squared = self._local_tensor.square()
    return self._wrap_local(squared)
def exp(self) -> "DSparseTensor":
    """Element-wise exponential of the stored entries (same spec)."""
    exponentiated = self._local_tensor.exp()
    return self._wrap_local(exponentiated)
def log(self) -> "DSparseTensor":
    """Element-wise natural log of the stored entries (same spec)."""
    logged = self._local_tensor.log()
    return self._wrap_local(logged)
def conj(self) -> "DSparseTensor":
    """Element-wise complex conjugate of the stored entries (same spec)."""
    conjugated = self._local_tensor.conj()
    return self._wrap_local(conjugated)
# Topology / structural queries -- collective via cached full_tensor.
def _global_view(self) -> "SparseTensor":
    """Return the gathered global matrix, memoised on ``self``.

    The first call runs :meth:`full_tensor` and stores the result as
    ``_full_tensor_cache``; later calls reuse it.  Nothing in this method
    invalidates the cache.
    """
    view = getattr(self, "_full_tensor_cache", None)
    if view is None:
        view = self.full_tensor()
        self._full_tensor_cache = view
    return view
def is_symmetric(self, atol: float = 1e-8, rtol: float = 1e-5) -> bool:
    """Delegate to ``SparseTensor.is_symmetric`` on the gathered global matrix."""
    full = self._global_view()
    verdict = full.is_symmetric(atol=atol, rtol=rtol)
    return bool(verdict)
def is_hermitian(self, atol: float = 1e-8, rtol: float = 1e-5) -> bool:
    """Delegate to ``SparseTensor.is_hermitian`` on the gathered global matrix."""
    full = self._global_view()
    verdict = full.is_hermitian(atol=atol, rtol=rtol)
    return bool(verdict)
def is_positive_definite(self) -> bool:
    """Delegate to ``SparseTensor.is_positive_definite`` on the gathered global matrix."""
    verdict = self._global_view().is_positive_definite()
    return bool(verdict)
def detect_matrix_type(self) -> str:
    """Delegate to ``SparseTensor.detect_matrix_type`` on the gathered global matrix."""
    full = self._global_view()
    kind = full.detect_matrix_type()
    return kind
def _gather_warn(self, op_name: str) -> "SparseTensor":
    """Warn that ``op_name`` gathers the whole matrix, then return the global view.

    Used by the gather-then-compute wrappers (det / lu / svd /
    condition_number / logdet fallback) that lack a truly distributed
    implementation.  ``stacklevel=3`` points the ResourceWarning at the
    code that called the public wrapper.
    """
    message = (
        f"DSparseTensor.{op_name}() falls back to full_tensor() "
        f"allgather + single-process compute; cost is O(global nnz). "
        f"Avoid in hot paths."
    )
    warnings.warn(message, ResourceWarning, stacklevel=3)
    return self._global_view()
def det(self) -> torch.Tensor:
    """Determinant of the gathered global matrix (emits a ResourceWarning)."""
    gathered = self._gather_warn("det")
    return gathered.det()
def lu(self):
    """LU factorisation of the gathered global matrix (emits a ResourceWarning)."""
    gathered = self._gather_warn("lu")
    return gathered.lu()
def svd(self, k: int = 6):
    """Truncated SVD (``k`` triplets) of the gathered global matrix (emits a ResourceWarning)."""
    gathered = self._gather_warn("svd")
    return gathered.svd(k=k)
def condition_number(self, ord: int = 2) -> torch.Tensor:
    """Condition number of the gathered global matrix (emits a ResourceWarning)."""
    gathered = self._gather_warn("condition_number")
    return gathered.condition_number(ord=ord)
def logdet(self, **kwargs) -> torch.Tensor:
    """Distributed log-determinant via Hutchinson + Lanczos.

    Options come from :func:`torch_sla.det._resolve` (see
    :mod:`torch_sla.det` for the :class:`DetConfig` knobs).

    * ``method='auto'`` selects ``'hutchinson'`` when
      :meth:`is_positive_definite` reports True, else ``'lu'``.  A failing
      check counts as not-PD.
    * ``method='hutchinson'``: no global gather.  The trace estimator
      only needs ``A @ z`` (routed through ``_shard_matvec``).  The
      backward solves use distributed CG (``cg_shard``), wired through
      ``HutchLogDetAdjoint`` so gradients reach
      ``self._local_tensor.values``.
    * anything else: ``full_tensor()`` gather + single-process
      ``logdet(**kwargs)``, with a ResourceWarning.
    """
    from ..det import _resolve

    opts = _resolve(**kwargs)
    method = opts["method"]
    N = int(self._shape[0])

    if method == "auto":
        try:
            is_pd = bool(self.is_positive_definite())
        except Exception:
            # Best effort: an undecidable PD check falls back to LU.
            is_pd = False
        method = "hutchinson" if is_pd else "lu"

    if method == "hutchinson":
        from ..det import HutchLogDetAdjoint
        from .collectives import gather_owned_to_global

        partition = self._spec.placement.partition
        owned = partition.owned_nodes.to(self.device).long()
        l2g = partition.local_to_global.to(self.device).long()

        local_st = self._local_tensor
        # The grad closure indexes the global z / x_solved vectors, so the
        # local COO coordinates are mapped to global ones up front.
        g_row = l2g[local_st.row_indices.to(torch.int64)]
        g_col = l2g[local_st.col_indices.to(torch.int64)]

        def matvec_fn(z):
            # Index unconditionally: correct for any ordering of ``owned``,
            # including a full-length permutation of the global rows.
            y_owned = self._shard_matvec(z[owned].contiguous())
            return gather_owned_to_global(owned, y_owned, z.shape[0])

        def solve_fn(z):
            # Distributed CG on the owned slice.  Assumes A is SPD (hence
            # symmetric), so a solve with A also serves A^T.
            from .solve import cg_shard
            from ..solve import _active_defaults

            cfg = _active_defaults() or {}
            x_owned = cg_shard(
                self,
                z[owned].contiguous(),
                M_apply=lambda r: r,
                atol=cfg.get("atol", 1e-8),
                rtol=cfg.get("rtol", 1e-8),
                maxiter=cfg.get("maxiter", 200),
                verbose=False,
            )
            return gather_owned_to_global(owned, x_owned, z.shape[0])

        def gather_fn(z, x_solved):
            # Gradient contribution of each local nnz, in global coords.
            return z[g_row] * x_solved[g_col]

        return HutchLogDetAdjoint.apply(
            local_st.values, matvec_fn, solve_fn, gather_fn,
            opts["num_probes"], opts["lanczos_iter"], opts["distribution"],
            opts.get("seed", 0), N, self.dtype, self.device,
        )

    # Fallback: gather + single-process logdet.
    return self._gather_warn("logdet").logdet(**kwargs)
def T(self) -> "DSparseTensor":
    """Transpose: gather the global matrix, transpose, and re-distribute.

    With ``mesh=None`` the transposed matrix is re-wrapped through
    :meth:`from_sparse_local` reusing the current partition and axis.
    Otherwise :meth:`partition` is run on the same mesh, using
    ``self._partition_method`` (or ``"simple"``) and ``self._coords``.
    """
    transposed = self._global_view().T()
    mesh = self._spec.mesh
    if mesh is not None:
        return DSparseTensor.partition(
            transposed,
            mesh,
            partition_method=self._partition_method or "simple",
            coords=self._coords,
        )
    old_placement = self._spec.placement
    return DSparseTensor.from_sparse_local(
        transposed,
        mesh=None,
        partition=old_placement.partition,
        axis=old_placement.axis,
        global_shape=tuple(transposed.shape),
    )
def H(self) -> "DSparseTensor":
    """Conjugate transpose: :meth:`conj` followed by :meth:`T`."""
    conjugated = self.conj()
    return conjugated.T()
def eigsh(self, k: int = 6, which: str = "LM", maxiter: int = 200,
          tol: float = 1e-8, return_eigenvectors: bool = True,
          sigma: Optional[float] = None, verbose: bool = False):
    """Partial eigendecomposition.

    * BatchShard: each rank calls ``SparseTensor.eigsh`` on its own batch
      slice with no inter-rank communication.  The result's leading axis
      is this rank's batch axis.  Only ``k``, ``which``, ``sigma`` and
      ``return_eigenvectors`` are forwarded in this mode.
    * Otherwise: distributed LOBPCG via
      :func:`torch_sla.distributed.eigsh.eigsh_shard`.
    """
    if not isinstance(self._spec.placement, BatchShard):
        from .eigsh import eigsh_shard
        return eigsh_shard(
            self, k=k, which=which, maxiter=maxiter, tol=tol,
            return_eigenvectors=return_eigenvectors,
            sigma=sigma, verbose=verbose,
        )
    return self._local_tensor.eigsh(
        k=k, which=which, sigma=sigma,
        return_eigenvectors=return_eigenvectors)
def solve_batch_shard(self, b: torch.Tensor, **kwargs) -> torch.Tensor:
    """Per-batch solve under :class:`BatchShard`, with no inter-rank communication.

    ``b`` is narrowed along ``placement.axis`` to this rank's
    ``[start, end)`` batch range.  The result comes from
    :meth:`SparseTensor.solve_batch` (same-pattern batched solve) on the
    local values stack.  Only this rank's slice of the solution is
    returned; gather it yourself if it is needed globally.

    Raises:
        RuntimeError: if the placement is not BatchShard.
    """
    from ..sparse_tensor import SparseTensor

    placement = self._spec.placement
    if not isinstance(placement, BatchShard):
        raise RuntimeError("solve_batch_shard requires BatchShard placement")
    width = placement.end - placement.start
    b_local = b.narrow(placement.axis, placement.start, width)

    local = self._local_tensor
    n_rows, n_cols = local.sparse_shape
    # The pattern is shared by every batch member; batch 0 is the template.
    pattern = SparseTensor(local.values[0], local.row_indices,
                           local.col_indices, (n_rows, n_cols))
    return pattern.solve_batch(local.values, b_local, **kwargs)
# =========================================================================
# DTensor Utilities
# =========================================================================

# =========================================================================
# Distributed Algorithms (True Distributed, No Gather)
# =========================================================================

# =========================================================================
# Methods that require data gather (with warnings)
# =========================================================================

# =========================================================================
# Matrix Operations
# =========================================================================
def __matmul__(self, x: Union[torch.Tensor, "DTensor"]) -> Union[torch.Tensor, "DTensor"]:
    """``D @ x``.

    VertexShard-family placements dispatch to
    :func:`torch_sla.distributed.matvec.matmul_spec`.  BatchShard
    placements dispatch to
    :func:`torch_sla.distributed.matvec.matmul_batch_shard`.

    Raises:
        RuntimeError: if there is no spec or the placement is unsupported.
    """
    from .matvec import matmul_spec, matmul_batch_shard

    spec = self._spec
    if spec is None:
        raise RuntimeError("DSparseTensor.__matmul__ requires a spec")
    placement = spec.placement
    if isinstance(placement, _VERTEX_SHARDS):
        return matmul_spec(self, x)
    if isinstance(placement, BatchShard):
        return matmul_batch_shard(self, x)
    raise RuntimeError(
        f"DSparseTensor.__matmul__ does not support "
        f"placement {type(placement).__name__}")

# ====================================================================== #
# Shard(0)-space distributed solve dispatcher.
#
# Every Krylov method dispatched below keeps all of its vectors local
# (length ``num_owned``) and uses ``dist.all_reduce`` for inner products.
# Matvec routes through ``_pad_owned_to_local`` so halo entries are
# filled by ``halo_exchange`` on every iteration.
# ====================================================================== #
def solve_distributed_shard(
    self,
    b: Any,
    *,
    method: Any = None,
    preconditioner: Any = None,
    atol: Any = None,
    rtol: Any = None,
    maxiter: Any = None,
    restart: Any = None,
    verbose: Any = None,
) -> Any:
    """Distributed Krylov solve entirely in Shard(0) space.

    Requires this :class:`DSparseTensor` to carry a real spec with a
    VertexShard-family placement (built via :meth:`from_local` or
    :meth:`partition`).  The right-hand side ``b`` may be a
    ``DTensor[Shard(0)]`` (most common) or a raw ``torch.Tensor`` sized
    ``num_owned`` for the calling rank; the return value mirrors the
    input's wrapper.

    Methods (all live in Shard(0) space):

    * ``"cg"`` / ``"pcg"``  Saad §6.7 conjugate gradient -- SPD systems
    * ``"bicgstab"``  Saad §7.4 BiCGStab -- non-symmetric, no restart
    * ``"gmres"``     Saad §6.5 restarted GMRES(m) -- general
    * ``"fgmres"``    Saad §9.4 flexible GMRES(m) -- variable preconditioner
    * ``"minres"``    Paige-Saunders MINRES -- symmetric indefinite

    Inner products go through ``dist.all_reduce`` (sum), matvec through
    ``halo_exchange`` -- no rank ever sees a global vector.

    SolverConfig integration
    ------------------------
    Every kwarg defaults to ``None``, meaning "look at the active
    :class:`SolverConfig` scope on this thread, then fall back to the
    hard-coded default".  The precedence chain matches :func:`solve`:
    explicit kwarg -> innermost scope -> outer scopes (LIFO) ->
    hard-coded default.  ``restart`` is not part of SolverConfig and
    defaults to 30.

    >>> with SolverConfig(method="bicgstab", atol=1e-8):
    ...     x = D.solve_distributed_shard(b)              # picks BiCGStab + 1e-8
    ...     y = D.solve_distributed_shard(b, atol=1e-12)  # kwarg wins

    Raises:
        RuntimeError: no VertexShard spec, or torch.distributed not initialised.
        ValueError: unknown ``method``.
    """
    if self._spec is None or not isinstance(
            self._spec.placement, _VERTEX_SHARDS):
        raise RuntimeError(
            "solve_distributed_shard() requires a DSparseTensor "
            "with VertexShard placement -- build one via "
            "DSparseTensor.from_local(local, mesh, ...) or "
            "DSparseTensor.partition(A, mesh, ...)."
        )
    if not (DIST_AVAILABLE and dist.is_initialized()):
        raise RuntimeError(
            "solve_distributed_shard() requires torch.distributed "
            "to be initialised."
        )

    # Merge with the active SolverConfig scope.  Guard against
    # ``_active_defaults()`` returning None (presumably when no scope is
    # active) so the membership checks below cannot raise TypeError.
    from ..solve import _active_defaults
    defaults = _active_defaults() or {}

    def _pick(value, name, hardcoded):
        """Explicit kwarg, else the scope's value, else ``hardcoded``."""
        if value is not None:
            return value
        if name in defaults:
            return defaults[name]
        return hardcoded

    method = _pick(method, "method", "cg")
    atol = _pick(atol, "atol", 1e-10)
    rtol = _pick(rtol, "rtol", 0.0)
    maxiter = _pick(maxiter, "maxiter", 1000)
    restart = restart if restart is not None else 30  # not in SolverConfig
    verbose = _pick(verbose, "verbose", False)

    # ``preconditioner=None`` (omitted or explicit -- indistinguishable
    # here) defers to the scope.  Membership rather than value is tested
    # because a scope may legitimately store ``None`` ("no
    # preconditioning"; SolverConfig uses an _UNSET sentinel to tell that
    # apart from "not set").  With no scope entry, ``None`` is passed to
    # make_preconditioner as-is (presumably the identity).
    if preconditioner is None and "preconditioner" in defaults:
        preconditioner = defaults["preconditioner"]

    from . import solve as _ds
    M_apply = _ds.make_preconditioner(self, preconditioner)

    if _is_dtensor(b):
        b_owned = b.to_local()
        wrap_output = True
    else:
        b_owned = b
        wrap_output = False

    method_l = method.lower()
    common = dict(M_apply=M_apply, atol=atol, rtol=rtol,
                  maxiter=maxiter, verbose=verbose)
    if method_l in ("cg", "pcg"):
        x_owned = _ds.cg_shard(self, b_owned, **common)
    elif method_l == "bicgstab":
        x_owned = _ds.bicgstab_shard(self, b_owned, **common)
    elif method_l in ("gmres", "fgmres"):
        x_owned = _ds.gmres_shard(
            self, b_owned, restart=restart,
            flexible=(method_l == "fgmres"), **common)
    elif method_l == "minres":
        x_owned = _ds.minres_shard(self, b_owned, **common)
    else:
        raise ValueError(
            f"Unknown distributed solve method {method!r}; expected "
            "one of cg, bicgstab, gmres, fgmres, minres."
        )

    if wrap_output:
        return DTensor.from_local(x_owned, self._spec.mesh, [Shard(0)])
    return x_owned
# ------------------------------------------------------------------ #
# Shard(0)-space primitives reused by every Krylov method.
# ------------------------------------------------------------------ #
def _shard_dot(self, u: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Global inner product of Shard(0) vectors.

    Computes a local ``torch.dot`` and SUM-reduces it across ranks via
    ``_all_reduce_scalar``.  Without an initialised process group the
    local result is returned instead of raising, which is the global
    value when a single process owns every row.
    """
    d = torch.dot(u, v)
    return self._all_reduce_scalar(d)

def _shard_norm(self, u: torch.Tensor) -> torch.Tensor:
    """Global Euclidean norm of a Shard(0) vector (always real-valued).

    Complex vectors use ``vdot`` (conjugating the first argument), so the
    result is ``sqrt(sum |u_i|^2)``.  The bilinear ``dot(u, u)`` would
    give ``sqrt(sum u_i^2)``, which is not a norm.
    """
    if not u.is_complex():
        return self._shard_dot(u, u).sqrt()
    local_sq = torch.vdot(u, u).real.clone()
    return self._all_reduce_scalar(local_sq).sqrt()

def _num_owned(self) -> int:
    """Owned-row count for the Shard(0) layout."""
    return int(self._spec.placement.partition.owned_nodes.numel())

def _shard_matvec(self, x_owned: torch.Tensor) -> torch.Tensor:
    """Hot-path matvec used by Krylov solvers.

    See :func:`torch_sla.distributed.matvec.shard_matvec`.
    """
    from .matvec import shard_matvec
    return shard_matvec(self, x_owned)

# ------------------------------------------------------------------ #
# The preconditioner factory and the five Krylov methods (CG / BiCGStab /
# GMRES / FGMRES / MINRES) live in :mod:`torch_sla.distributed.solve` as
# free functions taking ``self`` as ``D``.  ``solve_distributed_shard``
# above dispatches to them.
# ------------------------------------------------------------------ #
def __repr__(self) -> str:
    return (f"DSparseTensor(shape={tuple(self._shape)}, "
            f"num_partitions={self._num_partitions}, "
            f"local_nnz={self.nnz}, device={self._device})")