Architecture¶
torch-sla’s class hierarchy and distributed model mirror PyTorch’s own
torch.Tensor / torch.distributed.tensor.DTensor split: a single
sparse “local data” class, and a thin distributed wrapper that adds
placement + mesh metadata on top. This page is the source of truth for
the design contracts every new feature should respect.
Class hierarchy¶
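| Role | PyTorch | torch-sla |
|---|---|---|
| Local data | torch.Tensor | SparseTensor |
| Distributed wrapper (local data + spec) | DTensor | DSparseTensor |
| Per-rank local chunk | | |
| Distributed metadata | DTensorSpec | DSparseSpec |
| Sharding placement | Shard(dim) | SparseShard(axis) |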
Key invariant: SparseTensor is always local data. DSparseTensor
is always a distributed wrapper holding one rank’s SparseTensor plus
a spec. No “hybrid” class.
Shape contract: (*batch, M, N, *block)¶
A SparseTensor always has the canonical shape:
shape = (*batch_shape, M, N, *block_shape)
         └──dense───┘ └sparse└──dense──┘
          leading      2 dims  trailing
The two sparse dimensions are always the matrix axes M and
N – they cannot move. Dense axes flank them:
- batch_shape (left) – dense batch dims for batched SpMV / solve
- block_shape (right) – dense block dims for block-sparse formats (BSR / BCSC)
If a user has a tensor where the sparse axes aren’t in this slot,
SparseTensor.permute(...) reorders to the canonical layout. The
contract is positional, not by sparse-dim metadata, so every algorithm
(matvec, solve, eigsh, …) knows where to look.
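A positional contract means the split into batch, sparse, and block parts can be computed from the shape and a batch-dim count alone. The sketch below illustrates that idea in plain Python; `split_canonical_shape` and `canonical_permutation` are hypothetical helper names, not part of torch-sla, and the rule that leftover dense dims keep their relative order is an assumption.

```python
from typing import NamedTuple, Tuple


class CanonicalShape(NamedTuple):
    batch_shape: Tuple[int, ...]
    sparse_shape: Tuple[int, int]  # (M, N)
    block_shape: Tuple[int, ...]


def split_canonical_shape(shape, batch_ndim=0):
    """Split (*batch, M, N, *block) by position only (hypothetical helper)."""
    shape = tuple(shape)
    if len(shape) < batch_ndim + 2:
        raise ValueError("shape must contain the two sparse axes M and N")
    m, n = shape[batch_ndim], shape[batch_ndim + 1]
    return CanonicalShape(shape[:batch_ndim], (m, n), shape[batch_ndim + 2:])


def canonical_permutation(ndim, sparse_dims, batch_ndim):
    """Dim order that moves `sparse_dims` into slots batch_ndim, batch_ndim + 1.

    Assumption: the remaining dense dims keep their relative order; the first
    `batch_ndim` of them become batch dims and the rest become block dims.
    """
    row, col = sparse_dims
    dense = [d for d in range(ndim) if d not in (row, col)]
    return tuple(dense[:batch_ndim] + [row, col] + dense[batch_ndim:])
```

For a tensor laid out as (M, N, B1, B2), `canonical_permutation(4, (0, 1), 2)` gives the dim order one would pass to a permute call to reach (B1, B2, M, N).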
Placement vocabulary¶
A DSparseSpec carries:
- placement: how the data is sharded
- mesh: which devices it’s sharded over
- global_shape: the original full-tensor shape
placement is either a single placement (1-D mesh) or a list, one
element per mesh dimension (multi-D mesh – same convention as
DTensor).
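The three fields and the "single placement or list" rule can be pictured as a small immutable record. This is only an illustrative stand-in: `SpecSketch` and `make_spec` are hypothetical names, the mesh is reduced to its shape, and placements are treated as opaque objects rather than torch-sla's real placement classes.

```python
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass(frozen=True)
class SpecSketch:
    """Illustrative stand-in for a spec; not the library's DSparseSpec."""
    placements: Tuple[Any, ...]    # one entry per mesh dimension
    mesh_shape: Tuple[int, ...]    # stand-in for a real device mesh
    global_shape: Tuple[int, ...]  # (*batch, M, N, *block) before sharding


def make_spec(placement, mesh_shape, global_shape):
    # A bare placement means a 1-D mesh; a list means one entry per mesh dim.
    if isinstance(placement, list):
        placements = tuple(placement)
    else:
        placements = (placement,)
    if len(placements) != len(mesh_shape):
        raise ValueError(
            f"got {len(placements)} placements for a {len(mesh_shape)}-D mesh"
        )
    return SpecSketch(placements, tuple(mesh_shape), tuple(global_shape))
```

Normalising to a tuple up front means downstream code only ever handles the per-mesh-dimension form.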
| Class | Axis kind | Use case |
|---|---|---|
| | – | Full matrix on every rank |
| Shard(dim) | dense axis (batch or block) | Each rank gets a slice of batches; no cross-rank communication for SpMV |
| SparseShard(axis) | sparse axis (row or column) | Irregular row/col partition of the matrix; needs halo exchange or all-reduce |
| | sparse axis | Minimal-communication SpMV via PaToH / Mondriaan hypergraph cut |
Convenience constructors row_shard() and
col_shard() cover the common 2-D-matrix case:
from torch_sla import row_shard, col_shard, SparseShard
row_shard() # SparseShard(axis=0), plain (M, N)
col_shard() # SparseShard(axis=1)
row_shard(batch_ndim=2) # SparseShard(axis=2), for (B1, B2, M, N) tensor
Multi-axis sharding on a 2-D mesh: pass a list, exactly like DTensor.
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard
from torch_sla import SparseShard
# 2-D mesh: 4 batch shards × 8 row shards
mesh = init_device_mesh("cuda", (4, 8))
placement = [Shard(0),             # mesh dim 0: dense batch dim
             SparseShard(axis=2)]  # mesh dim 1: sparse row axis (batch_ndim=2)
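To see what a placement list implies for one rank, the following sketch computes a per-rank shape from a global shape, a mesh shape, and the tensor dim each mesh dim shards. It assumes an even, balanced split purely for illustration; real sparse partitions are irregular, and `chunk_size` and `local_shape` are hypothetical helpers rather than torch-sla or DTensor functions.

```python
def chunk_size(length, num_chunks, index):
    """Size of chunk `index` when `length` is split into nearly equal parts."""
    base, extra = divmod(length, num_chunks)
    return base + (1 if index < extra else 0)


def local_shape(global_shape, mesh_shape, sharded_dims, coord):
    """Per-rank shape when mesh dim i shards tensor dim sharded_dims[i].

    sharded_dims[i] is None for a mesh dim that does not shard anything.
    `coord` is the rank's coordinate in the mesh, one index per mesh dim.
    """
    shape = list(global_shape)
    for mesh_dim, tensor_dim in enumerate(sharded_dims):
        if tensor_dim is not None:
            shape[tensor_dim] = chunk_size(
                shape[tensor_dim], mesh_shape[mesh_dim], coord[mesh_dim]
            )
    return tuple(shape)


# (B1, B2, M, N) tensor on a 4 x 8 mesh: mesh dim 0 shards tensor dim 0,
# mesh dim 1 shards tensor dim 2 (the row axis when batch_ndim=2).
shape_at_origin = local_shape((8, 3, 1000, 1000), (4, 8), (0, 2), (0, 0))
```

Under these assumptions each of the 32 ranks holds a different (batch slice, row slice) pair, and the column axis stays whole on every rank.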
Matvec dispatch by placement¶
DSparseTensor.__matmul__ dispatches on placement to pick the right
communication pattern. Each row of this table is a separate code path:
| Placement | matvec algorithm | Cross-rank communication |
|---|---|---|
| | local | none |
| | per-batch independent SpMV | none (embarrassingly parallel) |
| | halo exchange + local SpMV | O(halo nnz) point-to-point |
| | local partial SpMV + all-reduce | O(M) all-reduce |
| 2-D placement list | 2-D Cannon / SUMMA | O(sqrt) better than 1-D for large mesh |
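The halo-exchange and partial-sum rows of the table correspond, in the usual formulation, to row-wise and column-wise partitions of the matrix. The single-process simulation below contrasts the two patterns on COO triples; it is a sketch of the general technique, not torch-sla's dispatch code. The function names are hypothetical, the "ranks" are a plain loop, and the row-sharded variant assumes a square matrix so that x is partitioned like the rows.

```python
def spmv(entries, x, num_rows):
    """Reference y = A @ x for A given as COO triples (row, col, value)."""
    y = [0.0] * num_rows
    for r, c, v in entries:
        y[r] += v * x[c]
    return y


def row_sharded_spmv(entries, x, owner, num_ranks):
    """Row sharding: rank p owns rows i (and x[i]) with owner[i] == p.

    Returns (y, halo_sizes), where halo_sizes[p] counts the remote x entries
    rank p must receive before it can run its local SpMV.
    """
    n = len(x)
    y = [0.0] * n
    halo_sizes = []
    for p in range(num_ranks):
        local = [(r, c, v) for r, c, v in entries if owner[r] == p]
        # Columns touched by local rows but owned elsewhere form the halo.
        halo = {c for _, c, _ in local if owner[c] != p}
        halo_sizes.append(len(halo))
        x_visible = {i: x[i] for i in range(n) if owner[i] == p or i in halo}
        for r, c, v in local:
            y[r] += v * x_visible[c]
    return y, halo_sizes


def col_sharded_spmv(entries, x, num_rows, col_owner, num_ranks):
    """Column sharding: each rank builds a full-length partial y."""
    partials = []
    for p in range(num_ranks):
        partial = [0.0] * num_rows
        for r, c, v in entries:
            if col_owner[c] == p:
                partial[r] += v * x[c]
        partials.append(partial)
    # The element-wise sum over ranks plays the role of the all-reduce.
    return [sum(vals) for vals in zip(*partials)]


# Example: 4 x 4 tridiagonal matrix split across 2 ranks.
n = 4
entries = [(i, j, 2.0 if i == j else -1.0)
           for i in range(n) for j in range(n) if abs(i - j) <= 1]
x = [1.0, 2.0, 3.0, 4.0]
owner = [0, 0, 1, 1]
y_ref = spmv(entries, x, n)
y_row, halo_sizes = row_sharded_spmv(entries, x, owner, num_ranks=2)
y_col = col_sharded_spmv(entries, x, n, owner, num_ranks=2)
```

Both variants reproduce the reference result. The difference is what moves between ranks: a handful of x entries in the row-sharded case, a full length-M vector in the column-sharded case.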
Partition algorithms¶
Picking how rows / cols get distributed is its own subproblem. The table below compares the options by partition quality relative to METIS, by the maturity of their Python bindings, and by the workloads each suits best:
| Algorithm | vs. METIS quality | Python binding | Best for |
|---|---|---|---|
| | much worse (no locality) | n/a (~10 LOC) | sanity tests, deterministic across ranks |
| METIS (current default) | baseline | | graphs up to ~100M nodes |
| Hilbert space-filling curve | worse but ~10-100x faster | pure Python or | PDE meshes / geometric structure |
| KaHIP | +20% quality | | graphs up to ~1B nodes |
| Mt-METIS | same quality, 4-16x faster | no Python binding; C call from ctypes | mid-size users with many CPU cores |
| PaToH (hypergraph) | minimal SpMV communication – theoretical optimum | | sparse matvec specifically |
| Mondriaan | similar to PaToH, 2-D-specific | command-line wrapper | sparse matrices specifically |
| ParMETIS | METIS-quality, distributed | no Python binding (MPI C only) | true HPC clusters |
| GNN-based learned | research-grade | DIY implementation | very-large / streaming graphs (>1B edges) |
torch-sla’s partition_for_rank() selects the partitioner through
the partition_method kwarg. Today it supports simple, metis, rcb,
and slicing; hilbert and patoh are tracked follow-ups.
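For a sense of scale, a structure-blind baseline of the kind the table's first row describes fits in roughly ten lines. The sketch below is an assumption about what such a baseline might look like, not the implementation behind partition_method="simple"; `contiguous_partition` and `edge_cut` are hypothetical names. The edge-cut counter is a rough proxy for how much communication a given assignment induces.

```python
def contiguous_partition(num_rows, num_ranks):
    """Assign rows to ranks in contiguous, nearly equal blocks.

    The result depends only on (num_rows, num_ranks), so every rank can
    compute the same assignment without communicating.
    """
    base, extra = divmod(num_rows, num_ranks)
    owner = []
    for rank in range(num_ranks):
        owner.extend([rank] * (base + (1 if rank < extra else 0)))
    return owner


def edge_cut(entries, owner):
    """Count off-diagonal nonzeros (row, col) whose ends sit on different ranks."""
    return sum(1 for r, c in entries if r != c and owner[r] != owner[c])
```

Because the assignment ignores the sparsity pattern, its edge cut depends entirely on how the rows happen to be ordered, which is the gap graph partitioners such as METIS are meant to close.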
Why this matches DTensor¶
Every design choice above maps 1:1 onto a corresponding DTensor decision:
- SparseTensor ≅ torch.Tensor – same “local data” role.
- DSparseTensor ≅ DTensor – same “(local + spec)” structure.
- SparseShard(axis) ≅ Shard(dim) – one parameterised placement per sharded direction, not separate classes.
- Placement list over mesh dims ≅ DTensor’s multi-axis sharding.
- Spec’s mesh + global_shape separation ≅ DTensorSpec.
By staying parallel to DTensor, torch-sla composes cleanly with
PyTorch’s distributed ecosystem (FSDP, TP, DCP) – a sparse vector
result from DSparseTensor.matvec is already a DTensor with the
right placement, ready to feed into a downstream FSDP module.