"""
SparseTensor wrapper class for PyTorch sparse tensors.
Supports batched and block sparse tensors with shape [...batch, M, N, ...block]:
- Leading dimensions: batch dimensions [B1, B2, ...]
- Matrix dimensions: (M, N) at positions (sparse_dim[0], sparse_dim[1]), default (-2, -1)
- Trailing dimensions: block dimensions [K1, K2, ...]
Key Features:
- Automatic symmetry and positive definiteness detection
- Sparse linear equation solving with gradient support
- Sparse-sparse multiplication with sparse gradients
- Batched operations for all methods
- CUDA support with LOBPCG for eigenvalue computation
Examples
--------
>>> # Create a simple sparse matrix
>>> val = torch.tensor([4.0, -1.0, -1.0, 4.0])
>>> row = torch.tensor([0, 0, 1, 1])
>>> col = torch.tensor([0, 1, 0, 1])
>>> A = SparseTensor(val, row, col, (2, 2))
>>>
>>> # Check properties (returns boolean tensor for batched)
>>> is_sym = A.is_symmetric() # tensor(True)
>>> is_pd = A.is_positive_definite() # tensor(True)
>>>
>>> # Solve linear system
>>> b = torch.tensor([1.0, 2.0])
>>> x = A.solve(b)
>>>
>>> # Matrix operations
>>> y = A @ x # Sparse @ Dense
>>> C = A @ A # Sparse @ Sparse (sparse gradient)
"""
import os
import torch
from torch.autograd.function import Function
from typing import Tuple, Optional, Union, Literal, List, Dict
import warnings
import math
from ..backends import (
is_scipy_available,
is_eigen_available,
is_cupy_available,
is_cudss_available,
select_backend,
select_method,
BackendType,
MethodType,
)
from ..backends.scipy_backend import (
scipy_solve,
scipy_eigs,
scipy_eigsh,
scipy_svds,
scipy_norm,
scipy_lu,
scipy_det,
)
from .autograd import (
DetAdjoint,
EigshAdjoint,
SparseSolveFunction,
SparseSparseMatmulFunction,
_sparse_sparse_matmul_with_sparse_grad,
)
# =============================================================================
# Utility Functions
# =============================================================================
[docs]
class SparseTensor:
"""
Wrapper class for PyTorch sparse tensors with batched and block support.
Supports tensors with shape [...batch, M, N, ...block] where:
- Leading dimensions [...batch] are batch dimensions
- (M, N) are the sparse matrix dimensions (at sparse_dim positions)
- Trailing dimensions [...block] are block dimensions
Parameters
----------
values : torch.Tensor
Non-zero values with shape:
- Simple: [nnz]
- Batched: [...batch, nnz]
- Block: [nnz, *block_shape]
- Batched+Block: [...batch, nnz, *block_shape]
row_indices : torch.Tensor
Row indices with shape [nnz]. Must be on the same device as values.
col_indices : torch.Tensor
Column indices with shape [nnz]. Must be on the same device as values.
shape : Tuple[int, ...]
Full tensor shape [...batch, M, N, *block_shape].
sparse_dim : Tuple[int, int], optional
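Which dimensions are sparse (M, N). Default: (-2, -1), i.e. the last two
dimensions of ``shape``. Tensors with trailing block dimensions must pass
``sparse_dim`` explicitly (e.g. ``(0, 1)`` for shape [M, N, K, K]).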
Attributes
----------
values : torch.Tensor
The non-zero values.
row_indices : torch.Tensor
Row indices of non-zeros.
col_indices : torch.Tensor
Column indices of non-zeros.
shape : Tuple[int, ...]
Full tensor shape.
sparse_shape : Tuple[int, int]
The (M, N) dimensions.
batch_shape : Tuple[int, ...]
The batch dimensions.
block_shape : Tuple[int, ...]
The block dimensions.
Examples
--------
**1. Simple 2D Sparse Matrix [M, N]**
>>> import torch
>>> from torch_sla import SparseTensor
>>>
>>> # Create a 3x3 tridiagonal matrix in COO format
>>> val = torch.tensor([4.0, -1.0, -1.0, 4.0, -1.0, -1.0, 4.0])
>>> row = torch.tensor([0, 0, 1, 1, 1, 2, 2])
>>> col = torch.tensor([0, 1, 0, 1, 2, 1, 2])
>>> A = SparseTensor(val, row, col, (3, 3))
>>> print(A)
SparseTensor(shape=(3, 3), sparse=(3, 3), nnz=7, dtype=torch.float32, device=cpu)
>>>
>>> # Solve Ax = b
>>> b = torch.tensor([1.0, 2.0, 3.0])
>>> x = A.solve(b)
**2. Batched Sparse Matrices [B, M, N]**
Same sparsity pattern, different values for each batch.
>>> # 4 matrices, each 3x3, same structure
>>> batch_size = 4
>>> val_batch = val.unsqueeze(0).expand(batch_size, -1).clone() # [4, 7]
>>> for i in range(batch_size):
... val_batch[i] = val * (1.0 + 0.1 * i) # Scale each matrix
>>>
>>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
>>> print(A_batch.batch_shape) # (4,)
>>> print(A_batch.sparse_shape) # (3, 3)
>>>
>>> # Batched solve
>>> b_batch = torch.randn(4, 3)
>>> x_batch = A_batch.solve(b_batch) # [4, 3]
**3. Multi-Dimensional Batch [B1, B2, M, N]**
>>> B1, B2 = 2, 3 # e.g., 2 materials x 3 temperatures
>>> val_batch = val.unsqueeze(0).unsqueeze(0).expand(B1, B2, -1).clone() # [2, 3, 7]
>>> A_multi = SparseTensor(val_batch, row, col, (B1, B2, 3, 3))
>>> print(A_multi.batch_shape) # (2, 3)
>>>
>>> b_multi = torch.randn(B1, B2, 3)
>>> x_multi = A_multi.solve(b_multi) # [2, 3, 3]
**4. Block Sparse Matrix [M, N, K, K] (Block Size K)**
Each non-zero entry is a KxK dense block instead of a scalar.
>>> # 2x2 block matrix with 2x2 blocks = 4x4 total
>>> block_size = 2
>>> nnz = 3 # 3 non-zero blocks
>>>
>>> # Values: [nnz, K, K] = [3, 2, 2]
>>> val_block = torch.randn(nnz, block_size, block_size)
>>> row_block = torch.tensor([0, 0, 1]) # Block row indices
>>> col_block = torch.tensor([0, 1, 1]) # Block col indices
>>>
>>> # Shape: (num_block_rows, num_block_cols, block_size, block_size)
>>> A_block = SparseTensor(val_block, row_block, col_block, (2, 2, 2, 2), sparse_dim=(0, 1))
>>> print(A_block.block_shape) # (2, 2)
>>> print(A_block.sparse_shape) # (2, 2) - number of blocks
>>> print(A_block.shape) # (2, 2, 2, 2) - full shape
**5. Batched Block Sparse [B, M, N, K, K]**
>>> batch_size = 4
>>> val_batch_block = torch.randn(batch_size, nnz, block_size, block_size) # [4, 3, 2, 2]
>>> A_batch_block = SparseTensor(val_batch_block, row_block, col_block, (4, 2, 2, 2, 2), sparse_dim=(1, 2))
>>> print(A_batch_block.batch_shape) # (4,)
>>> print(A_batch_block.block_shape) # (2, 2)
**6. Create from Dense Matrix**
>>> A_dense = torch.randn(100, 100)
>>> A_dense[A_dense.abs() < 0.5] = 0 # Sparsify
>>> A = SparseTensor.from_dense(A_dense)
**7. Create from PyTorch Sparse Tensor**
>>> A_torch = torch.randn(100, 100).to_sparse_coo()
>>> A = SparseTensor.from_torch_sparse(A_torch)
**8. Property Detection**
>>> A = SparseTensor(val, row, col, (3, 3))
>>> A.is_symmetric() # tensor(True) - returns tensor for batch support
>>> A.is_positive_definite() # tensor(True)
>>> A.is_positive_definite('cholesky') # Use Cholesky factorization check
**9. Matrix Operations**
>>> # Matrix-vector multiply
>>> y = A @ x # SparseTensor @ dense vector
>>>
>>> # Sparse-sparse multiply (returns SparseTensor with sparse gradients)
>>> C = A @ A
>>>
>>> # Norms
>>> A.norm('fro') # Frobenius norm
>>>
>>> # Eigenvalues (symmetric matrices)
>>> eigenvalues, eigenvectors = A.eigsh(k=2, which='LM')
**10. CUDA Support**
>>> A_cuda = A.cuda()
>>> x = A_cuda.solve(b.cuda()) # Uses cuDSS or CuPy
"""
def __init__(
self,
values: torch.Tensor,
row_indices: torch.Tensor,
col_indices: torch.Tensor,
shape: Tuple[int, ...],
sparse_dim: Tuple[int, int] = (-2, -1),
):
self.values = values
self.row_indices = row_indices
self.col_indices = col_indices
self._shape = tuple(shape)
self._sparse_dim = self._normalize_sparse_dim(sparse_dim, len(shape))
# Cache for computed properties
self._is_symmetric_cache = None
self._is_hermitian_cache = None
self._is_positive_definite_cache = None
self._validate()
def _normalize_sparse_dim(self, sparse_dim: Tuple[int, int], ndim: int) -> Tuple[int, int]:
"""Normalize negative indices in sparse_dim."""
dim_m = sparse_dim[0] if sparse_dim[0] >= 0 else ndim + sparse_dim[0]
dim_n = sparse_dim[1] if sparse_dim[1] >= 0 else ndim + sparse_dim[1]
return (dim_m, dim_n)
def _validate(self):
"""Validate tensor dimensions and indices."""
ndim = len(self._shape)
dim_m, dim_n = self._sparse_dim
if ndim < 2:
raise ValueError(f"Shape must have at least 2 dimensions, got {ndim}")
if not (0 <= dim_m < ndim and 0 <= dim_n < ndim):
raise ValueError(f"sparse_dim {self._sparse_dim} out of range for shape {self._shape}")
if dim_m == dim_n:
raise ValueError(f"sparse_dim dimensions must be different")
# =========================================================================
# Class Methods
# =========================================================================
[docs]
@classmethod
def from_dense(
cls,
A: torch.Tensor,
sparse_dim: Tuple[int, int] = (-2, -1)
) -> "SparseTensor":
"""
Create SparseTensor from dense tensor.
Parameters
----------
A : torch.Tensor
Dense tensor with shape [...batch, M, N, ...block].
sparse_dim : Tuple[int, int], optional
Which dimensions are sparse. Default: (-2, -1).
Returns
-------
SparseTensor
Sparse representation of A.
Examples
--------
>>> A_dense = torch.randn(3, 3)
>>> A_dense[A_dense.abs() < 0.5] = 0
>>> A = SparseTensor.from_dense(A_dense)
"""
ndim = A.dim()
dim_m = sparse_dim[0] if sparse_dim[0] >= 0 else ndim + sparse_dim[0]
dim_n = sparse_dim[1] if sparse_dim[1] >= 0 else ndim + sparse_dim[1]
if ndim == 2 and dim_m == 0 and dim_n == 1:
A_sparse = A.to_sparse_coo().coalesce()
indices = A_sparse.indices()
values = A_sparse.values()
return cls(values, indices[0], indices[1], tuple(A.shape), sparse_dim=sparse_dim)
perm = [i for i in range(ndim) if i not in (dim_m, dim_n)] + [dim_m, dim_n]
A_perm = A.permute(*perm)
batch_shape = A_perm.shape[:-2]
M, N = A_perm.shape[-2], A_perm.shape[-1]
A_flat = A_perm.reshape(-1, M, N)
A_2d = A_flat[0].to_sparse_coo()
indices = A_2d._indices()
row = indices[0]
col = indices[1]
nnz = row.size(0)
values = A_flat[:, row, col]
if len(batch_shape) > 0:
values = values.reshape(*batch_shape, nnz)
else:
values = values.squeeze(0)
return cls(values, row, col, tuple(A.shape), sparse_dim=sparse_dim)
[docs]
@classmethod
def from_torch_sparse(cls, A: torch.Tensor) -> "SparseTensor":
"""
Create SparseTensor from PyTorch sparse tensor.
Parameters
----------
A : torch.Tensor
PyTorch sparse COO or CSR tensor (2D only).
Returns
-------
SparseTensor
SparseTensor representation.
Examples
--------
>>> A_coo = torch.randn(3, 3).to_sparse_coo()
>>> A = SparseTensor.from_torch_sparse(A_coo)
"""
if A.layout == torch.sparse_csr:
A = A.to_sparse_coo()
A = A.coalesce()
indices = A.indices()
values = A.values()
return cls(values, indices[0], indices[1], tuple(A.shape))
[docs]
@classmethod
def eye(cls, n: int, dtype: torch.dtype = torch.float64,
device: Union[str, torch.device] = "cpu") -> "SparseTensor":
"""Sparse identity ``n x n``."""
idx = torch.arange(n, dtype=torch.int64, device=device)
return cls(torch.ones(n, dtype=dtype, device=device), idx, idx, shape=(n, n))
[docs]
@classmethod
def diag(cls, values: torch.Tensor,
device: Optional[Union[str, torch.device]] = None) -> "SparseTensor":
"""Sparse diagonal matrix from a 1-D vector."""
if values.dim() != 1:
raise ValueError(f"diag needs a 1-D tensor, got shape {tuple(values.shape)}")
n = int(values.numel())
device = device if device is not None else values.device
idx = torch.arange(n, dtype=torch.int64, device=device)
return cls(values.to(device), idx, idx, shape=(n, n))
[docs]
@classmethod
def tridiagonal(cls, n: int,
diag: Union[float, torch.Tensor] = 2.0,
off_diag: Union[float, torch.Tensor] = -1.0,
dtype: torch.dtype = torch.float64,
device: Union[str, torch.device] = "cpu") -> "SparseTensor":
"""Sparse symmetric tridiagonal ``n x n``. ``diag=4, off=-1`` is the
canonical SPD test matrix; ``diag=2, off=-1`` is the 1-D Laplacian.
``diag`` / ``off_diag`` accept scalars or matching-length tensors."""
device = torch.device(device)
def _vec(v, length, name):
if isinstance(v, torch.Tensor):
if v.dim() != 1 or v.numel() != length:
raise ValueError(f"{name} must have shape ({length},), got {tuple(v.shape)}")
return v.to(device=device, dtype=dtype)
return torch.full((length,), float(v), dtype=dtype, device=device)
diag_v = _vec(diag, n, "diag")
off_v = _vec(off_diag, n - 1, "off_diag")
idx = torch.arange(n, dtype=torch.int64, device=device)
vals = torch.cat([diag_v, off_v, off_v])
row = torch.cat([idx, idx[1:], idx[:-1]])
col = torch.cat([idx, idx[:-1], idx[1:]])
return cls(vals, row, col, shape=(n, n))
# =========================================================================
# Properties
# =========================================================================
@property
def shape(self) -> Tuple[int, ...]:
"""Full tensor shape [...batch, M, N, ...block]."""
return self._shape
@property
def sparse_shape(self) -> Tuple[int, int]:
"""The (M, N) sparse matrix dimensions."""
dim_m, dim_n = self._sparse_dim
return (self._shape[dim_m], self._shape[dim_n])
@property
def batch_shape(self) -> Tuple[int, ...]:
"""The batch dimensions before the sparse dimensions."""
dim_m, dim_n = self._sparse_dim
min_dim = min(dim_m, dim_n)
return self._shape[:min_dim]
@property
def block_shape(self) -> Tuple[int, ...]:
"""The block dimensions after the sparse dimensions."""
dim_m, dim_n = self._sparse_dim
max_dim = max(dim_m, dim_n)
return self._shape[max_dim + 1:]
@property
def sparse_dim(self) -> Tuple[int, int]:
"""The dimensions that are sparse (M, N)."""
return self._sparse_dim
@property
def ndim(self) -> int:
"""Number of dimensions."""
return len(self._shape)
@property
def nnz(self) -> int:
"""Number of non-zero elements (per batch/block)."""
return self.row_indices.size(0)
@property
def dtype(self) -> torch.dtype:
"""Data type of the values."""
return self.values.dtype
@property
def device(self) -> torch.device:
"""Device of the tensor."""
return self.values.device
@property
def is_cuda(self) -> bool:
"""Whether the tensor is on CUDA."""
return self.values.is_cuda
@property
def is_batched(self) -> bool:
"""Whether the tensor has batch dimensions."""
return len(self.batch_shape) > 0
@property
def is_block(self) -> bool:
"""Whether the tensor has block dimensions."""
return len(self.block_shape) > 0
@property
def batch_size(self) -> int:
"""Total number of batch elements (product of batch_shape)."""
return math.prod(self.batch_shape) if self.batch_shape else 1
@property
def is_square(self) -> bool:
"""Whether the sparse dimensions are square (M == N)."""
M, N = self.sparse_shape
return M == N
# =========================================================================
# Device and Type Management
# =========================================================================
[docs]
def to(
self,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None
) -> "SparseTensor":
"""
Move tensor to device and/or convert dtype.
Parameters
----------
device : str or torch.device, optional
Target device (e.g., 'cuda', 'cpu', 'cuda:0').
dtype : torch.dtype, optional
Target data type (e.g., torch.float32, torch.float64).
Returns
-------
SparseTensor
New SparseTensor on the target device/dtype.
Examples
--------
>>> A = SparseTensor(val, row, col, shape)
>>> A_cuda = A.to('cuda')
>>> A_float32 = A.to(dtype=torch.float32)
>>> A_cuda_float32 = A.to('cuda', torch.float32)
"""
new_values = self.values
new_row = self.row_indices
new_col = self.col_indices
if device is not None:
new_values = new_values.to(device)
new_row = new_row.to(device)
new_col = new_col.to(device)
if dtype is not None:
new_values = new_values.to(dtype)
result = SparseTensor(
new_values, new_row, new_col, self._shape,
sparse_dim=self._sparse_dim
)
return result
[docs]
def cuda(self, device: Optional[int] = None) -> "SparseTensor":
"""
Move tensor to CUDA device.
Parameters
----------
device : int, optional
CUDA device index. Default: current device.
Returns
-------
SparseTensor
Tensor on CUDA.
"""
if device is None:
return self.to('cuda')
return self.to(f'cuda:{device}')
[docs]
def cpu(self) -> "SparseTensor":
"""
Move tensor to CPU.
Returns
-------
SparseTensor
Tensor on CPU.
"""
return self.to('cpu')
[docs]
def float(self) -> "SparseTensor":
"""Convert to float32."""
return self.to(dtype=torch.float32)
[docs]
def double(self) -> "SparseTensor":
"""Convert to float64."""
return self.to(dtype=torch.float64)
[docs]
def half(self) -> "SparseTensor":
"""Convert to float16."""
return self.to(dtype=torch.float16)
[docs]
def requires_grad_(self, requires_grad: bool = True) -> "SparseTensor":
    """Toggle gradient tracking, in the spirit of
    :meth:`torch.Tensor.requires_grad_`.

    This is a thin shim: the implementation lives in :mod:`.ops` and is
    called with ``self`` and ``requires_grad``; whatever it returns is
    returned here. The shim is kept on the class so that existing
    ``A.requires_grad_(True)`` call sites keep working after the body was
    moved out of the class.

    NOTE(review): the original documentation states that the helper flips
    ``self.values.requires_grad`` in place, leaves the index tensors
    untouched and returns ``self`` -- confirm in :mod:`.ops`.
    """
    from .ops import requires_grad_ as _set_requires_grad
    return _set_requires_grad(self, requires_grad)
[docs]
# Conversion helpers: every method below forwards ``self`` plus all
# positional/keyword arguments to the function of the same name in
# ``.convert``. The import is done lazily inside each method.

def to_torch_sparse(self, *args, **kwargs):
    """Forward to ``.convert.to_torch_sparse``."""
    from .convert import to_torch_sparse as _to_torch_sparse
    return _to_torch_sparse(self, *args, **kwargs)

def to_dense(self, *args, **kwargs):
    """Forward to ``.convert.to_dense``."""
    from .convert import to_dense as _to_dense
    return _to_dense(self, *args, **kwargs)

def to_csr(self, *args, **kwargs):
    """Forward to ``.convert.to_csr``."""
    from .convert import to_csr as _to_csr
    return _to_csr(self, *args, **kwargs)

def save_distributed(self, *args, **kwargs):
    """Forward to ``.convert.save_distributed``."""
    from .convert import save_distributed as _save_distributed
    return _save_distributed(self, *args, **kwargs)

def partition_for_rank(self, *args, **kwargs):
    """Forward to ``.convert.partition_for_rank``."""
    from .convert import partition_for_rank as _partition_for_rank
    return _partition_for_rank(self, *args, **kwargs)

def detect_matrix_type(self, *args, **kwargs):
    """Forward to ``.convert.detect_matrix_type``."""
    from .convert import detect_matrix_type as _detect_matrix_type
    return _detect_matrix_type(self, *args, **kwargs)

def T(self, *args, **kwargs):
    """Forward to ``.convert.T``. Note this is a method, not a property."""
    from .convert import T as _transpose
    return _transpose(self, *args, **kwargs)

def conj(self, *args, **kwargs):
    """Forward to ``.convert.conj``."""
    from .convert import conj as _conj
    return _conj(self, *args, **kwargs)

def H(self, *args, **kwargs):
    """Forward to ``.convert.H``. Note this is a method, not a property."""
    from .convert import H as _hermitian
    return _hermitian(self, *args, **kwargs)

def flatten_blocks(self, *args, **kwargs):
    """Forward to ``.convert.flatten_blocks``."""
    from .convert import flatten_blocks as _flatten_blocks
    return _flatten_blocks(self, *args, **kwargs)

def unflatten_blocks(self, *args, **kwargs):
    """Forward to ``.convert.unflatten_blocks``."""
    from .convert import unflatten_blocks as _unflatten_blocks
    return _unflatten_blocks(self, *args, **kwargs)
[docs]
# Structural checks: every method below forwards ``self`` plus all
# arguments to the function of the same name in ``.structural``.

def is_symmetric(self, *args, **kwargs):
    """Forward to ``.structural.is_symmetric``."""
    from .structural import is_symmetric as _is_symmetric
    return _is_symmetric(self, *args, **kwargs)

def is_hermitian(self, *args, **kwargs):
    """Forward to ``.structural.is_hermitian``."""
    from .structural import is_hermitian as _is_hermitian
    return _is_hermitian(self, *args, **kwargs)

def is_positive_definite(self, *args, **kwargs):
    """Forward to ``.structural.is_positive_definite``."""
    from .structural import is_positive_definite as _is_positive_definite
    return _is_positive_definite(self, *args, **kwargs)

def _check_pair_match(self, *args, **kwargs):
    """Forward to ``.structural._check_pair_match``."""
    from .structural import _check_pair_match as _pair_match
    return _pair_match(self, *args, **kwargs)

def _check_pd_gershgorin(self, *args, **kwargs):
    """Forward to ``.structural._check_pd_gershgorin``."""
    from .structural import _check_pd_gershgorin as _pd_gershgorin
    return _pd_gershgorin(self, *args, **kwargs)

def _check_pd_cholesky(self, *args, **kwargs):
    """Forward to ``.structural._check_pd_cholesky``."""
    from .structural import _check_pd_cholesky as _pd_cholesky
    return _pd_cholesky(self, *args, **kwargs)

def _check_pd_eigenvalue(self, *args, **kwargs):
    """Forward to ``.structural._check_pd_eigenvalue``."""
    from .structural import _check_pd_eigenvalue as _pd_eigenvalue
    return _pd_eigenvalue(self, *args, **kwargs)

def _batch_indices(self, *args, **kwargs):
    """Forward to ``.structural._batch_indices``."""
    from .structural import _batch_indices as _structural_batch_indices
    return _structural_batch_indices(self, *args, **kwargs)
[docs]
# Graph helpers: each method forwards ``self`` plus all arguments to the
# function of the same name in ``.graph``.

def connected_components(self, *args, **kwargs):
    """Forward to ``.graph.connected_components``."""
    from .graph import connected_components as _connected_components
    return _connected_components(self, *args, **kwargs)

def has_isolated_components(self, *args, **kwargs):
    """Forward to ``.graph.has_isolated_components``."""
    from .graph import has_isolated_components as _has_isolated_components
    return _has_isolated_components(self, *args, **kwargs)

def to_connected_components(self, *args, **kwargs):
    """Forward to ``.graph.to_connected_components``."""
    from .graph import to_connected_components as _to_connected_components
    return _to_connected_components(self, *args, **kwargs)
# Multiplication: each method forwards ``self`` plus all arguments to the
# function of the same name in ``.matmul``. ``__matmul__`` and
# ``__rmatmul__`` are the hooks Python calls for the ``@`` operator.

def _spmv_coo(self, *args, **kwargs):
    """Forward to ``.matmul._spmv_coo``."""
    from .matmul import _spmv_coo as _spmv
    return _spmv(self, *args, **kwargs)

def _dense_sparse_mm(self, *args, **kwargs):
    """Forward to ``.matmul._dense_sparse_mm``."""
    from .matmul import _dense_sparse_mm as _dense_sparse
    return _dense_sparse(self, *args, **kwargs)

def _spsp_multiply(self, *args, **kwargs):
    """Forward to ``.matmul._spsp_multiply``."""
    from .matmul import _spsp_multiply as _spsp
    return _spsp(self, *args, **kwargs)

def __matmul__(self, *args, **kwargs):
    """Implement ``self @ other`` via ``.matmul.__matmul__``."""
    from .matmul import __matmul__ as _matmul
    return _matmul(self, *args, **kwargs)

def __rmatmul__(self, *args, **kwargs):
    """Implement ``other @ self`` via ``.matmul.__rmatmul__``."""
    from .matmul import __rmatmul__ as _rmatmul
    return _rmatmul(self, *args, **kwargs)
[docs]
# Linear / nonlinear solves: each method forwards ``self`` plus all
# arguments to the function of the same name in ``.linalg``.

def solve(self, *args, **kwargs):
    """Forward to ``.linalg.solve``."""
    from .linalg import solve as _solve
    return _solve(self, *args, **kwargs)

def solve_batch(self, *args, **kwargs):
    """Forward to ``.linalg.solve_batch``."""
    from .linalg import solve_batch as _solve_batch
    return _solve_batch(self, *args, **kwargs)

def nonlinear_solve(self, *args, **kwargs):
    """Forward to ``.linalg.nonlinear_solve``."""
    from .linalg import nonlinear_solve as _nonlinear_solve
    return _nonlinear_solve(self, *args, **kwargs)
# =========================================================================
# Norms
# =========================================================================
[docs]
def norm(self, ord: Literal['fro', 1, 2] = 'fro') -> torch.Tensor:
"""
Compute matrix norm.
For batched tensors, returns norm for each batch element.
Parameters
----------
ord : {'fro', 1, 2}, optional
Norm type:
- 'fro': Frobenius norm (default)
- 1: Maximum absolute column sum
- 2: Spectral norm (largest singular value)
Returns
-------
torch.Tensor
Norm value(s). Shape [] for non-batched, [*batch_shape] for batched.
Examples
--------
>>> A = SparseTensor(val, row, col, (3, 3))
>>> A.norm('fro') # tensor(5.0)
>>> A_batch = SparseTensor(val_batch, row, col, (4, 3, 3))
>>> A_batch.norm('fro') # tensor([5.0, 5.0, 5.0, 5.0])
"""
if self.is_batched:
batch_shape = self.batch_shape
vals_flat = self.values.reshape(-1, self.nnz)
norms = []
for i in range(vals_flat.size(0)):
if ord == 'fro':
norms.append(vals_flat[i].norm())
else:
idx = self._flat_to_batch_idx(i)
A_dense = self.to_dense(idx)
norms.append(torch.linalg.norm(A_dense, ord=ord))
return torch.stack(norms).reshape(*batch_shape)
else:
if ord == 'fro':
return self.values.norm()
if self.is_cuda or not is_scipy_available():
A = self.to_dense()
return torch.linalg.norm(A, ord=ord)
M, N = self.sparse_shape
return scipy_norm(self.values, self.row_indices, self.col_indices, (M, N), ord=ord)
def _flat_to_batch_idx(self, flat_idx: int) -> Tuple[int, ...]:
"""Convert flat batch index to tuple."""
idx = []
for s in reversed(self.batch_shape):
idx.append(flat_idx % s)
flat_idx //= s
return tuple(reversed(idx))
[docs]
# Visualisation and spectral / factorisation helpers. ``spy`` forwards to
# ``.viz``; all others forward ``self`` plus all arguments to the function
# of the same name in ``.linalg``.

def spy(self, *args, **kwargs):
    """Render the sparsity pattern. See :func:`viz.spy`."""
    from .viz import spy as _render_pattern
    return _render_pattern(self, *args, **kwargs)

def eigs(self, *args, **kwargs):
    """Forward to ``.linalg.eigs``."""
    from .linalg import eigs as _eigs
    return _eigs(self, *args, **kwargs)

def eigsh(self, *args, **kwargs):
    """Forward to ``.linalg.eigsh``."""
    from .linalg import eigsh as _eigsh
    return _eigsh(self, *args, **kwargs)

def svd(self, *args, **kwargs):
    """Forward to ``.linalg.svd``."""
    from .linalg import svd as _svd
    return _svd(self, *args, **kwargs)

def condition_number(self, *args, **kwargs):
    """Forward to ``.linalg.condition_number``."""
    from .linalg import condition_number as _condition_number
    return _condition_number(self, *args, **kwargs)

def det(self, *args, **kwargs):
    """Forward to ``.linalg.det``."""
    from .linalg import det as _det
    return _det(self, *args, **kwargs)

def logdet(self, *args, **kwargs):
    """Forward to ``.linalg.logdet``."""
    from .linalg import logdet as _logdet
    return _logdet(self, *args, **kwargs)

def lu(self, *args, **kwargs):
    """Forward to ``.linalg.lu``."""
    from .linalg import lu as _lu
    return _lu(self, *args, **kwargs)
# =========================================================================
# String Representation
# =========================================================================
def __repr__(self) -> str:
parts = [f"SparseTensor(shape={self._shape}"]
if self.is_batched:
parts.append(f"batch={self.batch_shape}")
parts.append(f"sparse={self.sparse_shape}")
if self.is_block:
parts.append(f"block={self.block_shape}")
parts.append(f"nnz={self.nnz}")
parts.append(f"dtype={self.dtype}")
parts.append(f"device={self.device}")
return ", ".join(parts) + ")"
[docs]
# Reductions: each method forwards ``self`` plus all arguments to a
# function in ``.reductions``. The public reductions map to ``_<name>_impl``
# there; the private helpers map to the function of the same name.

def sum(self, *args, **kwargs):
    """Forward to ``.reductions._sum_impl``."""
    from .reductions import _sum_impl as _reduce_sum
    return _reduce_sum(self, *args, **kwargs)

def _sum_over_sparse(self, *args, **kwargs):
    """Forward to ``.reductions._sum_over_sparse``."""
    from .reductions import _sum_over_sparse as _reduce_sparse
    return _reduce_sparse(self, *args, **kwargs)

def _sum_over_batch_block(self, *args, **kwargs):
    """Forward to ``.reductions._sum_over_batch_block``."""
    from .reductions import _sum_over_batch_block as _reduce_batch_block
    return _reduce_batch_block(self, *args, **kwargs)

def mean(self, *args, **kwargs):
    """Forward to ``.reductions._mean_impl``."""
    from .reductions import _mean_impl as _reduce_mean
    return _reduce_mean(self, *args, **kwargs)

def prod(self, *args, **kwargs):
    """Forward to ``.reductions._prod_impl``."""
    from .reductions import _prod_impl as _reduce_prod
    return _reduce_prod(self, *args, **kwargs)

def max(self, *args, **kwargs):
    """Forward to ``.reductions._max_impl``."""
    from .reductions import _max_impl as _reduce_max
    return _reduce_max(self, *args, **kwargs)

def min(self, *args, **kwargs):
    """Forward to ``.reductions._min_impl``."""
    from .reductions import _min_impl as _reduce_min
    return _reduce_min(self, *args, **kwargs)

def _normalize_axis(self, *args, **kwargs):
    """Forward to ``.reductions._normalize_axis``."""
    from .reductions import _normalize_axis as _norm_axis
    return _norm_axis(self, *args, **kwargs)

def _get_dim_type(self, *args, **kwargs):
    """Forward to ``.reductions._get_dim_type``."""
    from .reductions import _get_dim_type as _dim_type
    return _dim_type(self, *args, **kwargs)

def _values_axis_for_dim(self, *args, **kwargs):
    """Forward to ``.reductions._values_axis_for_dim``."""
    from .reductions import _values_axis_for_dim as _values_axis
    return _values_axis(self, *args, **kwargs)
# Arithmetic and element-wise math: each method forwards ``self`` plus all
# arguments to a function in ``.ops``. Operator dunders map to the function
# of the same name; the named math methods map to ``_<name>_impl``.

def _apply_elementwise(self, *args, **kwargs):
    """Forward to ``.ops._apply_elementwise``."""
    from .ops import _apply_elementwise as _elementwise
    return _elementwise(self, *args, **kwargs)

def __add__(self, *args, **kwargs):
    """Implement ``self + other`` via ``.ops.__add__``."""
    from .ops import __add__ as _add
    return _add(self, *args, **kwargs)

def __radd__(self, *args, **kwargs):
    """Implement ``other + self`` via ``.ops.__radd__``."""
    from .ops import __radd__ as _radd
    return _radd(self, *args, **kwargs)

def __sub__(self, *args, **kwargs):
    """Implement ``self - other`` via ``.ops.__sub__``."""
    from .ops import __sub__ as _sub
    return _sub(self, *args, **kwargs)

def __rsub__(self, *args, **kwargs):
    """Implement ``other - self`` via ``.ops.__rsub__``."""
    from .ops import __rsub__ as _rsub
    return _rsub(self, *args, **kwargs)

def __mul__(self, *args, **kwargs):
    """Implement ``self * other`` via ``.ops.__mul__``."""
    from .ops import __mul__ as _mul
    return _mul(self, *args, **kwargs)

def __rmul__(self, *args, **kwargs):
    """Implement ``other * self`` via ``.ops.__rmul__``."""
    from .ops import __rmul__ as _rmul
    return _rmul(self, *args, **kwargs)

def __truediv__(self, *args, **kwargs):
    """Implement ``self / other`` via ``.ops.__truediv__``."""
    from .ops import __truediv__ as _truediv
    return _truediv(self, *args, **kwargs)

def __rtruediv__(self, *args, **kwargs):
    """Implement ``other / self`` via ``.ops.__rtruediv__``."""
    from .ops import __rtruediv__ as _rtruediv
    return _rtruediv(self, *args, **kwargs)

def __floordiv__(self, *args, **kwargs):
    """Implement ``self // other`` via ``.ops.__floordiv__``."""
    from .ops import __floordiv__ as _floordiv
    return _floordiv(self, *args, **kwargs)

def __pow__(self, *args, **kwargs):
    """Implement ``self ** other`` via ``.ops.__pow__``."""
    from .ops import __pow__ as _pow
    return _pow(self, *args, **kwargs)

def __neg__(self, *args, **kwargs):
    """Implement ``-self`` via ``.ops.__neg__``."""
    from .ops import __neg__ as _neg
    return _neg(self, *args, **kwargs)

def __pos__(self, *args, **kwargs):
    """Implement ``+self`` via ``.ops.__pos__``."""
    from .ops import __pos__ as _pos
    return _pos(self, *args, **kwargs)

def __abs__(self, *args, **kwargs):
    """Implement ``abs(self)`` via ``.ops.__abs__``."""
    from .ops import __abs__ as _builtin_abs
    return _builtin_abs(self, *args, **kwargs)

def abs(self, *args, **kwargs):
    """Forward to ``.ops._abs_impl``."""
    from .ops import _abs_impl as _abs
    return _abs(self, *args, **kwargs)

def sqrt(self, *args, **kwargs):
    """Forward to ``.ops._sqrt_impl``."""
    from .ops import _sqrt_impl as _sqrt
    return _sqrt(self, *args, **kwargs)

def square(self, *args, **kwargs):
    """Forward to ``.ops._square_impl``."""
    from .ops import _square_impl as _square
    return _square(self, *args, **kwargs)

def exp(self, *args, **kwargs):
    """Forward to ``.ops._exp_impl``."""
    from .ops import _exp_impl as _exp
    return _exp(self, *args, **kwargs)

def log(self, *args, **kwargs):
    """Forward to ``.ops._log_impl``."""
    from .ops import _log_impl as _log
    return _log(self, *args, **kwargs)
# =========================================================================
# Persistence (I/O)
# =========================================================================
[docs]
def save(
    self,
    path: Union[str, "os.PathLike"],
    metadata: Optional[Dict[str, str]] = None
) -> None:
    """
    Write this SparseTensor to disk in safetensors format.

    The work is done by ``save_sparse`` from the package's ``io`` module,
    which receives ``self``, ``path`` and ``metadata`` unchanged.

    Parameters
    ----------
    path : str or PathLike
        Output file path (should end with .safetensors).
    metadata : dict, optional
        Additional metadata to store.

    Example
    -------
    >>> A = SparseTensor(val, row, col, (100, 100))
    >>> A.save("matrix.safetensors")
    """
    from ..io import save_sparse as _write_sparse
    _write_sparse(self, path, metadata)
[docs]
@classmethod
def load(
    cls,
    path: Union[str, "os.PathLike"],
    device: Union[str, torch.device] = "cpu"
) -> "SparseTensor":
    """
    Read a SparseTensor from a safetensors file.

    The work is done by ``load_sparse`` from the package's ``io`` module;
    its result is returned as-is (``cls`` itself is not used).

    Parameters
    ----------
    path : str or PathLike
        Input file path.
    device : str or torch.device
        Device to load tensors to.

    Returns
    -------
    SparseTensor
        The loaded sparse tensor.

    Example
    -------
    >>> A = SparseTensor.load("matrix.safetensors", device="cuda")
    """
    from ..io import load_sparse as _read_sparse
    loaded = _read_sparse(path, device)
    return loaded
# =============================================================================
# LUFactorization Class
# =============================================================================
[docs]
class LUFactorization:
    """
    LU factorization wrapper for efficient repeated solves.

    Created by SparseTensor.lu().

    Parameters
    ----------
    lu_factor : scipy.sparse.linalg.SuperLU
        The SciPy LU factorization object. Only its ``solve`` method is used.
    shape : Tuple[int, int]
        Matrix shape.
    dtype : torch.dtype
        Data type of the solutions returned by :meth:`solve`.
    device : torch.device
        Device of the solutions returned by :meth:`solve`.

    Examples
    --------
    >>> A = SparseTensor(val, row, col, (10, 10))
    >>> lu = A.lu()
    >>> x1 = lu.solve(b1) # First solve
    >>> x2 = lu.solve(b2) # Much faster - reuses factorization
    """

    def __init__(self, lu_factor, shape: Tuple[int, int], dtype: torch.dtype, device: torch.device):
        self._lu = lu_factor
        self._shape = shape
        self._dtype = dtype
        self._device = device

    def solve(self, b: torch.Tensor) -> torch.Tensor:
        """
        Solve Ax = b using the cached factorization.

        The right-hand side is detached and copied to host memory, so no
        gradient flows through this call.

        Parameters
        ----------
        b : torch.Tensor
            Right-hand side vector.

        Returns
        -------
        torch.Tensor
            Solution x, converted to the dtype and device given at
            construction time.
        """
        # The wrapped factorization works on NumPy arrays on the host.
        rhs = b.detach().cpu().numpy()
        solution = self._lu.solve(rhs)
        return torch.from_numpy(solution).to(dtype=self._dtype, device=self._device)

    def __repr__(self) -> str:
        return f"LUFactorization(shape={self._shape})"
# =============================================================================
# SparseTensorList Class
# =============================================================================