# Source code for torch_sla.sparse_tensor.list

"""``SparseTensorList`` for batched-with-different-pattern collections."""
from __future__ import annotations

from typing import List, Optional, Tuple, Union, Literal

import torch

from .core import SparseTensor


class SparseTensorList:
    """
    A list of SparseTensors with different structures.

    Provides a unified interface for batch operations on matrices with
    different sparsity patterns. Unlike batched SparseTensor (which requires
    same structure), SparseTensorList allows each matrix to have different
    shape and sparsity pattern.

    Almost every method simply applies the same-named ``SparseTensor``
    operation to each stored tensor, in order, and collects the results.

    Parameters
    ----------
    tensors : List[SparseTensor]
        List of SparseTensor objects. Must not be empty.

    Attributes
    ----------
    shapes : List[Tuple[int, ...]]
        List of shapes for each tensor.
    device : torch.device
        Device (from first tensor).
    dtype : torch.dtype
        Data type (from first tensor).

    Raises
    ------
    ValueError
        If ``tensors`` is empty.

    Examples
    --------
    >>> # Create matrices with different sizes
    >>> A1 = SparseTensor(val1, row1, col1, (10, 10))
    >>> A2 = SparseTensor(val2, row2, col2, (20, 20))
    >>> A3 = SparseTensor(val3, row3, col3, (30, 30))
    >>> # Create list
    >>> matrices = SparseTensorList([A1, A2, A3])
    >>> print(matrices.shapes)  # [(10, 10), (20, 20), (30, 30)]
    >>> # Batch solve
    >>> x_list = matrices.solve([b1, b2, b3])
    >>> # Check properties for all
    >>> is_sym = matrices.is_symmetric()  # [tensor(True), tensor(True), tensor(True)]
    """

    def __init__(self, tensors: List["SparseTensor"]):
        if not tensors:
            raise ValueError("SparseTensorList cannot be empty")
        # Shallow copy: later mutation of the caller's list must not change
        # this container.
        self._tensors = list(tensors)

    @classmethod
    def from_coo_list(
        cls,
        matrices: List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor, Tuple[int, ...]]],
    ) -> "SparseTensorList":
        """
        Create from list of COO data tuples.

        Parameters
        ----------
        matrices : List[Tuple]
            List of (values, row_indices, col_indices, shape) tuples.

        Returns
        -------
        SparseTensorList
            List of SparseTensors, one per input tuple.

        Raises
        ------
        ValueError
            If ``matrices`` is empty.

        Examples
        --------
        >>> data = [
        ...     (val1, row1, col1, (10, 10)),
        ...     (val2, row2, col2, (20, 20)),
        ... ]
        >>> matrices = SparseTensorList.from_coo_list(data)
        """
        tensors = [
            SparseTensor(val, row, col, shape)
            for val, row, col, shape in matrices
        ]
        return cls(tensors)

    @classmethod
    def from_torch_sparse_list(cls, A_list: List[torch.Tensor]) -> "SparseTensorList":
        """
        Create from list of PyTorch sparse tensors.

        Parameters
        ----------
        A_list : List[torch.Tensor]
            List of PyTorch sparse COO tensors.

        Returns
        -------
        SparseTensorList
            List of SparseTensors, each built with
            ``SparseTensor.from_torch_sparse``.

        Raises
        ------
        ValueError
            If ``A_list`` is empty.
        """
        tensors = [SparseTensor.from_torch_sparse(A) for A in A_list]
        return cls(tensors)

    @property
    def shapes(self) -> List[Tuple[int, ...]]:
        """List of shapes for each tensor."""
        return [t.shape for t in self._tensors]

    @property
    def device(self) -> torch.device:
        """Device of the first tensor."""
        return self._tensors[0].device

    @property
    def dtype(self) -> torch.dtype:
        """Data type of the first tensor."""
        return self._tensors[0].dtype

    def __len__(self) -> int:
        """Number of tensors in the list."""
        return len(self._tensors)

    def __getitem__(self, idx: int) -> "SparseTensor":
        """
        Get tensor by index.

        Parameters
        ----------
        idx : int
            Index (supports negative indexing).

        Returns
        -------
        SparseTensor
            The tensor at that index.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``.
        """
        # Delegate to the list directly. Adding ``len`` to a negative index
        # first would make the list wrap it a second time, so an
        # out-of-range index such as ``-(len + 1)`` would silently return
        # the last tensor instead of raising.
        return self._tensors[idx]

    def __iter__(self):
        """Iterate over tensors."""
        return iter(self._tensors)

    def to(self, device: Union[str, torch.device]) -> "SparseTensorList":
        """
        Move all tensors to device.

        Parameters
        ----------
        device : str or torch.device
            Target device.

        Returns
        -------
        SparseTensorList
            New list with tensors on target device.
        """
        return SparseTensorList([t.to(device) for t in self._tensors])

    def cuda(self) -> "SparseTensorList":
        """Move all tensors to CUDA."""
        return self.to('cuda')

    def cpu(self) -> "SparseTensorList":
        """Move all tensors to CPU."""
        return self.to('cpu')

    # =========================================================================
    # Private helpers
    # =========================================================================

    def _check_same_length(self, other: "SparseTensorList") -> None:
        """Raise ``ValueError`` unless ``other`` holds as many tensors as ``self``."""
        if len(other) != len(self._tensors):
            raise ValueError(f"Length mismatch: {len(self._tensors)} vs {len(other)}")

    def _map_values(self, fn) -> "SparseTensorList":
        """Rebuild each tensor with ``fn(values)``, keeping its indices and shape."""
        return SparseTensorList([
            SparseTensor(fn(t.values), t.row_indices, t.col_indices, t.shape)
            for t in self._tensors
        ])

    # =========================================================================
    # Arithmetic Operations
    # =========================================================================

    def __matmul__(self, x_list: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
        """
        Batch matrix-vector/matrix multiplication.

        Parameters
        ----------
        x_list : List[torch.Tensor] or torch.Tensor
            If List: one vector/matrix per sparse tensor, each with compatible shape.
            If Tensor: broadcasted to all matrices (must have compatible shape for all).

        Returns
        -------
        List[torch.Tensor]
            List of results [A1 @ x1, A2 @ x2, ...] or [A1 @ x, A2 @ x, ...]

        Raises
        ------
        ValueError
            If a list is given whose length differs from the number of
            matrices.

        Examples
        --------
        >>> matrices = SparseTensorList([A1, A2, A3])
        >>> # Per-matrix vectors
        >>> y_list = matrices @ [x1, x2, x3]
        >>> # Broadcast same vector
        >>> y_list = matrices @ x  # x applied to all
        """
        if isinstance(x_list, torch.Tensor):
            # Broadcast same tensor to all
            return [t @ x_list for t in self._tensors]

        if len(x_list) != len(self._tensors):
            raise ValueError(f"Expected {len(self._tensors)} vectors, got {len(x_list)}")
        return [t @ x for t, x in zip(self._tensors, x_list)]

    def __add__(self, other: Union["SparseTensorList", float, int]) -> "SparseTensorList":
        """
        Element-wise addition.

        Parameters
        ----------
        other : SparseTensorList or scalar
            If SparseTensorList: add corresponding matrices (must have same length).
            If scalar: added to the stored values of every matrix; the
            row/column indices and shape are reused unchanged.

        Returns
        -------
        SparseTensorList
            Result of addition.

        Raises
        ------
        ValueError
            If ``other`` is a SparseTensorList of a different length.
        """
        if isinstance(other, SparseTensorList):
            self._check_same_length(other)
            return SparseTensorList([a + b for a, b in zip(self._tensors, other._tensors)])
        # Scalar addition - add to values
        return self._map_values(lambda values: values + other)

    def __radd__(self, other):
        """Reflected addition; same as ``self + other``."""
        return self.__add__(other)

    def __sub__(self, other: Union["SparseTensorList", float, int]) -> "SparseTensorList":
        """
        Element-wise subtraction.

        Parameters
        ----------
        other : SparseTensorList or scalar
            If SparseTensorList: subtract corresponding matrices (must have
            same length). If scalar: subtracted from the stored values of
            every matrix.

        Returns
        -------
        SparseTensorList
            Result of subtraction.

        Raises
        ------
        ValueError
            If ``other`` is a SparseTensorList of a different length.
        """
        if isinstance(other, SparseTensorList):
            self._check_same_length(other)
            return SparseTensorList([a - b for a, b in zip(self._tensors, other._tensors)])
        return self._map_values(lambda values: values - other)

    def __rsub__(self, other):
        """Reflected subtraction: ``other - values`` for every matrix."""
        return self._map_values(lambda values: other - values)

    def __mul__(self, other: Union["SparseTensorList", float, int, torch.Tensor]) -> "SparseTensorList":
        """
        Element-wise multiplication.

        Parameters
        ----------
        other : SparseTensorList, scalar, or Tensor
            If SparseTensorList: multiply corresponding matrices element-wise.
            If scalar/Tensor: each matrix is multiplied by ``other``.

        Returns
        -------
        SparseTensorList
            Result of multiplication.

        Raises
        ------
        ValueError
            If ``other`` is a SparseTensorList of a different length.
        """
        if isinstance(other, SparseTensorList):
            self._check_same_length(other)
            return SparseTensorList([a * b for a, b in zip(self._tensors, other._tensors)])
        return SparseTensorList([t * other for t in self._tensors])

    def __rmul__(self, other):
        """Reflected multiplication; same as ``self * other``."""
        return self.__mul__(other)

    def __truediv__(self, other: Union[float, int, torch.Tensor]) -> "SparseTensorList":
        """Element-wise division by scalar."""
        return SparseTensorList([t / other for t in self._tensors])

    def __neg__(self) -> "SparseTensorList":
        """Negate all values."""
        return SparseTensorList([-t for t in self._tensors])

    def sum(self, axis: Optional[int] = None) -> Union[List[torch.Tensor], torch.Tensor]:
        """
        Sum values in each matrix.

        Parameters
        ----------
        axis : int, optional
            If None: sum all values in each matrix, return List[scalar].
            If 0: sum over rows for each matrix.
            If 1: sum over columns for each matrix.

        Returns
        -------
        List[torch.Tensor]
            One entry per matrix, in order: scalar tensors if axis is None,
            1D tensors if axis is 0 or 1.

        Examples
        --------
        >>> matrices = SparseTensorList([A1, A2, A3])
        >>> totals = matrices.sum()  # [sum(A1), sum(A2), sum(A3)]
        >>> row_sums = matrices.sum(axis=1)  # [A1.sum(1), A2.sum(1), ...]
        """
        return [t.sum(axis=axis) for t in self._tensors]

    def mean(self, axis: Optional[int] = None) -> List[torch.Tensor]:
        """
        Mean of values in each matrix.

        Parameters
        ----------
        axis : int, optional
            Same as sum().

        Returns
        -------
        List[torch.Tensor]
            List of mean values/vectors.
        """
        return [t.mean(axis=axis) for t in self._tensors]

    def max(self) -> List[torch.Tensor]:
        """Maximum value in each matrix."""
        return [t.max() for t in self._tensors]

    def min(self) -> List[torch.Tensor]:
        """Minimum value in each matrix."""
        return [t.min() for t in self._tensors]

    def abs(self) -> "SparseTensorList":
        """Absolute value of all elements."""
        return SparseTensorList([t.abs() for t in self._tensors])

    def clamp(self, min: Optional[float] = None, max: Optional[float] = None) -> "SparseTensorList":
        """Clamp values in all matrices."""
        return SparseTensorList([t.clamp(min=min, max=max) for t in self._tensors])

    def pow(self, exponent: float) -> "SparseTensorList":
        """Element-wise power."""
        return SparseTensorList([t.pow(exponent) for t in self._tensors])

    def sqrt(self) -> "SparseTensorList":
        """Element-wise square root."""
        return SparseTensorList([t.sqrt() for t in self._tensors])

    def exp(self) -> "SparseTensorList":
        """Element-wise exponential."""
        return SparseTensorList([t.exp() for t in self._tensors])

    def log(self) -> "SparseTensorList":
        """Element-wise natural logarithm."""
        return SparseTensorList([t.log() for t in self._tensors])

    # =========================================================================
    # Linear Algebra
    # =========================================================================

    def solve(self, b_list: List[torch.Tensor], **kwargs) -> List[torch.Tensor]:
        """
        Solve linear systems for all matrices.

        Parameters
        ----------
        b_list : List[torch.Tensor]
            List of right-hand side vectors, one per matrix.
        **kwargs
            Additional arguments passed to SparseTensor.solve().

        Returns
        -------
        List[torch.Tensor]
            List of solutions.

        Raises
        ------
        ValueError
            If ``b_list`` does not contain exactly one entry per matrix.

        Examples
        --------
        >>> matrices = SparseTensorList([A1, A2, A3])
        >>> x_list = matrices.solve([b1, b2, b3])
        """
        if len(b_list) != len(self._tensors):
            raise ValueError(f"Expected {len(self._tensors)} RHS vectors, got {len(b_list)}")
        return [t.solve(b, **kwargs) for t, b in zip(self._tensors, b_list)]

    def is_symmetric(self, **kwargs) -> List[torch.Tensor]:
        """
        Check symmetry for all matrices.

        Parameters
        ----------
        **kwargs
            Arguments passed to SparseTensor.is_symmetric().

        Returns
        -------
        List[torch.Tensor]
            List of boolean tensors.
        """
        return [t.is_symmetric(**kwargs) for t in self._tensors]

    def is_positive_definite(self, **kwargs) -> List[torch.Tensor]:
        """
        Check positive definiteness for all matrices.

        Parameters
        ----------
        **kwargs
            Arguments passed to SparseTensor.is_positive_definite().

        Returns
        -------
        List[torch.Tensor]
            List of boolean tensors.
        """
        return [t.is_positive_definite(**kwargs) for t in self._tensors]

    def norm(self, ord: Literal['fro', 1, 2] = 'fro') -> List[torch.Tensor]:
        """
        Compute norms for all matrices.

        Parameters
        ----------
        ord : {'fro', 1, 2}
            Norm type.

        Returns
        -------
        List[torch.Tensor]
            List of norm values.
        """
        return [t.norm(ord=ord) for t in self._tensors]

    def eigs(self, k: int = 6, **kwargs) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """
        Compute eigenvalues for all matrices.

        Parameters
        ----------
        k : int
            Number of eigenvalues.
        **kwargs
            Additional arguments passed to SparseTensor.eigs().

        Returns
        -------
        List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
            List of (eigenvalues, eigenvectors) tuples.
        """
        return [t.eigs(k=k, **kwargs) for t in self._tensors]

    def eigsh(self, k: int = 6, **kwargs) -> List[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
        """
        Compute eigenvalues for symmetric matrices.

        Parameters
        ----------
        k : int
            Number of eigenvalues.
        **kwargs
            Additional arguments passed to SparseTensor.eigsh().

        Returns
        -------
        List[Tuple[torch.Tensor, Optional[torch.Tensor]]]
            List of (eigenvalues, eigenvectors) tuples.
        """
        return [t.eigsh(k=k, **kwargs) for t in self._tensors]

    def svd(self, k: int = 6) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """
        Compute SVD for all matrices.

        Parameters
        ----------
        k : int
            Number of singular values.

        Returns
        -------
        List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
            List of (U, S, Vt) tuples.
        """
        return [t.svd(k=k) for t in self._tensors]

    def condition_number(self, ord: int = 2) -> List[torch.Tensor]:
        """
        Compute condition numbers for all matrices.

        Parameters
        ----------
        ord : int
            Norm order.

        Returns
        -------
        List[torch.Tensor]
            List of condition numbers.
        """
        return [t.condition_number(ord=ord) for t in self._tensors]

    def det(self) -> List[torch.Tensor]:
        """
        Compute determinants for all matrices.

        Returns
        -------
        List[torch.Tensor]
            List of determinant values.

        Examples
        --------
        >>> matrices = SparseTensorList([A1, A2, A3])
        >>> dets = matrices.det()
        >>> print([d.item() for d in dets])
        """
        return [t.det() for t in self._tensors]

    def spy(
        self,
        indices: Optional[List[int]] = None,
        ncols: int = 3,
        figsize: Optional[Tuple[float, float]] = None,
        **kwargs
    ):
        """
        Visualize sparsity patterns for multiple matrices in a grid.

        Parameters
        ----------
        indices : List[int], optional
            Which matrices to visualize. Default: all.
        ncols : int, optional
            Number of columns in subplot grid. Default: 3.
        figsize : Tuple[float, float], optional
            Figure size. Auto-computed as ``(4 * ncols, 4 * nrows)`` if None.
        **kwargs
            Additional arguments passed to SparseTensor.spy().
            ``show_colorbar`` defaults to False but may be overridden here.

        Returns
        -------
        fig : matplotlib.figure.Figure
            The figure object.

        Raises
        ------
        ImportError
            If matplotlib is not installed.
        ValueError
            If ``ncols`` is smaller than 1 or ``indices`` is empty.

        Examples
        --------
        >>> matrices = SparseTensorList([A1, A2, A3, A4])
        >>> matrices.spy()  # Visualize all in grid
        >>> matrices.spy(indices=[0, 2])  # Visualize specific ones
        """
        try:
            import matplotlib.pyplot as plt
        except ImportError as exc:
            raise ImportError(
                "matplotlib is required for spy(). Install with: pip install matplotlib"
            ) from exc

        if ncols < 1:
            raise ValueError(f"ncols must be at least 1, got {ncols}")

        if indices is None:
            indices = list(range(len(self._tensors)))

        n = len(indices)
        if n == 0:
            # A grid with zero rows cannot be created.
            raise ValueError("indices must select at least one matrix")

        nrows = (n + ncols - 1) // ncols  # ceil(n / ncols)

        if figsize is None:
            figsize = (4 * ncols, 4 * nrows)

        # squeeze=False keeps ``axes`` 2-D even for a single row or column.
        fig, axes = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)

        # Default only: a caller-supplied show_colorbar wins instead of
        # colliding with a hard-coded keyword argument.
        spy_kwargs = {'show_colorbar': False, **kwargs}

        for i, idx in enumerate(indices):
            row, col = divmod(i, ncols)
            ax = axes[row, col]
            tensor = self._tensors[idx]
            tensor.spy(ax=ax, **spy_kwargs)
            M, N = tensor.sparse_shape
            ax.set_title(f'[{idx}] {M}×{N}, nnz={tensor.nnz:,}', fontsize=10)

        # Hide unused subplots
        for i in range(n, nrows * ncols):
            row, col = divmod(i, ncols)
            axes[row, col].axis('off')

        plt.tight_layout()
        return fig

    # =========================================================================
    # Conversion Methods
    # =========================================================================

    def to_block_diagonal(self) -> "SparseTensor":
        """
        Merge all matrices into a single block-diagonal SparseTensor.

        Creates a sparse matrix where each input matrix appears as a block
        on the diagonal: diag(A1, A2, ..., An).

        Returns
        -------
        SparseTensor
            Block-diagonal matrix with shape (sum(M_i), sum(N_i)). When the
            list holds a single tensor, that tensor itself is returned (not
            a copy).

        Raises
        ------
        ValueError
            If the list holds no tensors.

        Notes
        -----
        The resulting matrix has the structure:
        ```
        [A1  0   0  ...]
        [ 0  A2  0  ...]
        [ 0  0   A3 ...]
        [... ... ... ]
        ```

        Examples
        --------
        >>> A1 = SparseTensor(val1, row1, col1, (10, 10))
        >>> A2 = SparseTensor(val2, row2, col2, (20, 20))
        >>> stl = SparseTensorList([A1, A2])
        >>> A_block = stl.to_block_diagonal()  # Shape (30, 30)
        """
        if not self._tensors:
            raise ValueError("Cannot convert empty SparseTensorList to block diagonal")

        if len(self._tensors) == 1:
            return self._tensors[0]

        # Shift each block's indices by the running row/column offsets of
        # the blocks placed before it, then concatenate the COO data.
        all_values = []
        all_rows = []
        all_cols = []
        row_offset = 0
        col_offset = 0
        for t in self._tensors:
            M, N = t.sparse_shape
            all_values.append(t.values)
            all_rows.append(t.row_indices + row_offset)
            all_cols.append(t.col_indices + col_offset)
            row_offset += M
            col_offset += N

        values = torch.cat(all_values)
        rows = torch.cat(all_rows)
        cols = torch.cat(all_cols)

        # After the loop the offsets equal the total number of rows/columns.
        return SparseTensor(values, rows, cols, (row_offset, col_offset))

    @classmethod
    def from_block_diagonal(
        cls,
        sparse: "SparseTensor",
        sizes: List[Tuple[int, int]]
    ) -> "SparseTensorList":
        """
        Split a block-diagonal SparseTensor into a list of matrices.

        Only entries that fall inside one of the diagonal blocks described
        by ``sizes`` are copied; any entry outside those blocks is not part
        of the result.

        Parameters
        ----------
        sparse : SparseTensor
            Block-diagonal matrix to split.
        sizes : List[Tuple[int, int]]
            List of (rows, cols) for each block. Must sum to
            ``sparse.sparse_shape``.

        Returns
        -------
        SparseTensorList
            List of extracted blocks.

        Raises
        ------
        NotImplementedError
            If ``sparse`` is batched.
        ValueError
            If the block sizes do not sum to the shape of ``sparse``, or if
            ``sizes`` is empty.

        Examples
        --------
        >>> A_block = SparseTensor(val, row, col, (30, 30))
        >>> stl = SparseTensorList.from_block_diagonal(A_block, [(10, 10), (20, 20)])
        >>> print(len(stl))  # 2
        """
        if sparse.is_batched:
            raise NotImplementedError("from_block_diagonal not supported for batched tensors")

        # Validate sizes
        total_rows = sum(s[0] for s in sizes)
        total_cols = sum(s[1] for s in sizes)
        if (total_rows, total_cols) != sparse.sparse_shape:
            raise ValueError(
                f"Sizes sum to ({total_rows}, {total_cols}) but sparse has shape {sparse.sparse_shape}"
            )

        row = sparse.row_indices
        col = sparse.col_indices
        val = sparse.values

        tensors = []
        r_start = 0
        c_start = 0
        for m, n in sizes:
            r_end = r_start + m
            c_end = c_start + n

            # Find entries in this block
            mask = (row >= r_start) & (row < r_end) & (col >= c_start) & (col < c_end)

            # Re-base the indices so each block starts at (0, 0).
            block_row = row[mask] - r_start
            block_col = col[mask] - c_start
            block_val = val[mask]

            tensors.append(SparseTensor(block_val, block_row, block_col, (m, n)))

            r_start, c_start = r_end, c_end

        return cls(tensors)

    @property
    def block_sizes(self) -> List[Tuple[int, int]]:
        """
        Get the (rows, cols) size of each matrix.

        Returns
        -------
        List[Tuple[int, int]]
            List of (M, N) tuples.
        """
        return [t.sparse_shape for t in self._tensors]

    @property
    def total_nnz(self) -> int:
        """Total number of non-zeros across all matrices."""
        return sum(t.nnz for t in self._tensors)

    @property
    def total_shape(self) -> Tuple[int, int]:
        """Shape of the block-diagonal representation."""
        total_rows = sum(t.sparse_shape[0] for t in self._tensors)
        total_cols = sum(t.sparse_shape[1] for t in self._tensors)
        return (total_rows, total_cols)

    def __repr__(self) -> str:
        return f"SparseTensorList(n={len(self._tensors)}, device={self.device})"