From 14a82a52d08dc1733c9cd5d48aba6d2e15f29808 Mon Sep 17 00:00:00 2001 From: Gilles Daviet Date: Tue, 18 Jun 2024 12:43:15 +0200 Subject: [PATCH] Add natural operators, transpose bsr_mv --- warp/sparse.py | 348 +++++++++++++++++++++++++++++++++----- warp/tests/test_sparse.py | 33 ++-- warp/types.py | 3 + 3 files changed, 330 insertions(+), 54 deletions(-) diff --git a/warp/sparse.py b/warp/sparse.py index 56937fe85..5e74ac50c 100644 --- a/warp/sparse.py +++ b/warp/sparse.py @@ -118,6 +118,59 @@ def _nnz_transfer_buf_and_event(self): return self._nnz_buf, ctypes.c_void_p(None) return self._nnz_buf, self._nnz_event.cuda_event + # Overloaded math operators + def __add__(self, y): + return bsr_axpy(y, bsr_copy(self)) + + def __iadd__(self, y): + return bsr_axpy(y, self) + + def __radd__(self, x): + return bsr_axpy(x, bsr_copy(self)) + + def __sub__(self, y): + return bsr_axpy(y, bsr_copy(self), alpha=-1.0) + + def __rsub__(self, x): + return bsr_axpy(x, bsr_copy(self), beta=-1.0) + + def __isub__(self, y): + return bsr_axpy(y, self, alpha=-1.0) + + def __mul__(self, y): + return _BsrScalingExpression(self, y) + + def __rmul__(self, x): + return _BsrScalingExpression(self, x) + + def __imul__(self, y): + return bsr_scale(self, y) + + def __matmul__(self, y): + if isinstance(y, wp.array): + return bsr_mv(self, y) + + return bsr_mm(self, y) + + def __rmatmul__(self, x): + if isinstance(x, wp.array): + return bsr_mv(self, x, transpose=True) + + return bsr_mm(x, self) + + def __imatmul__(self, y): + return bsr_mm(self, y, self) + + def __truediv__(self, y): + return _BsrScalingExpression(self, 1.0 / y) + + def __neg__(self): + return _BsrScalingExpression(self, -1.0) + + def transpose(self): + """Returns a transposed copy of this matrix""" + return bsr_transposed(self) + def bsr_matrix_t(dtype: BlockType): dtype = wp.types.type_to_warp(dtype) @@ -319,6 +372,116 @@ def bsr_set_from_triplets( ) +class _BsrExpression(Generic[_BlockType]): + pass + + +class 
_BsrScalingExpression(_BsrExpression):
+    def __init__(self, mat, scale):  # represents `scale * mat`; `mat` is held by reference, not copied
+        self.mat = mat
+        self.scale = scale
+
+    def eval(self):  # materialize the expression as a new matrix via bsr_copy
+        return bsr_copy(self)
+
+    @property
+    def nrow(self) -> int:
+        return self.mat.nrow
+
+    @property
+    def ncol(self) -> int:
+        return self.mat.ncol
+
+    @property
+    def nnz(self) -> int:
+        return self.mat.nnz
+
+    @property
+    def offsets(self) -> wp.array:
+        return self.mat.offsets
+
+    @property
+    def columns(self) -> wp.array:
+        return self.mat.columns
+
+    @property
+    def scalar_type(self) -> Scalar:
+        return self.mat.scalar_type
+
+    @property
+    def block_shape(self) -> Tuple[int, int]:
+        return self.mat.block_shape
+
+    @property
+    def block_size(self) -> int:
+        return self.mat.block_size
+
+    @property
+    def shape(self) -> Tuple[int, int]:
+        return self.mat.shape
+
+    @property
+    def dtype(self) -> type:
+        return self.mat.dtype
+
+    @property
+    def device(self) -> wp.context.Device:
+        return self.mat.device
+
+    # Overloaded math operators; bsr_axpy(x, y, alpha, beta) computes alpha * x + beta * y
+    def __add__(self, y):
+        return bsr_axpy(y, bsr_copy(self.mat), beta=self.scale)  # y + scale * mat
+
+    def __radd__(self, x):
+        return bsr_axpy(x, bsr_copy(self.mat), beta=self.scale)  # x + scale * mat
+
+    def __sub__(self, y):
+        return bsr_axpy(y, bsr_copy(self.mat), alpha=-1.0, beta=self.scale)  # scale * mat - y
+
+    def __rsub__(self, x):
+        return bsr_axpy(x, bsr_copy(self.mat), beta=-self.scale)  # x - scale * mat
+
+    def __mul__(self, y):
+        return _BsrScalingExpression(self.mat, y * self.scale)
+
+    def __rmul__(self, x):
+        return _BsrScalingExpression(self.mat, x * self.scale)
+
+    def __matmul__(self, y):
+        if isinstance(y, wp.array):
+            return bsr_mv(self.mat, y, alpha=self.scale)
+
+        return bsr_mm(self.mat, y, alpha=self.scale)
+
+    def __rmatmul__(self, x):
+        if isinstance(x, wp.array):
+            return bsr_mv(self.mat, x, alpha=self.scale, transpose=True)  # x @ (scale * mat) as scale * mat^T x
+
+        return bsr_mm(x, self.mat, alpha=self.scale)
+
+    def __truediv__(self, y):
+        return _BsrScalingExpression(self.mat, self.scale / y)
+
+    def __neg__(self):
+        return _BsrScalingExpression(self.mat, -self.scale)
+
+    def 
transpose(self): + """Returns a transposed copy of this matrix""" + return _BsrScalingExpression(self.mat.transpose(), self.scale) + + +BsrMatrixOrExpression = Union[BsrMatrix[_BlockType], _BsrExpression[_BlockType]] + + +def _extract_matrix_and_scale(bsr: BsrMatrixOrExpression): + if isinstance(bsr, BsrMatrix): + return bsr, 1.0 + if isinstance(bsr, _BsrScalingExpression): + return bsr.mat, bsr.scale + + raise ValueError("Argument cannot be interpreted as a BsrMatrix") + + @wp.kernel def _bsr_assign_split_offsets( row_factor: int, @@ -341,6 +504,7 @@ def _bsr_assign_split_offsets( @wp.kernel def _bsr_assign_split_blocks( structure_only: wp.bool, + scale: Any, row_factor: int, col_factor: int, dest_row_count: int, @@ -377,7 +541,9 @@ def _bsr_assign_split_blocks( src_base_j = split_col * dest_cols_per_block for i in range(dest_rows_per_block): for j in range(dest_cols_per_block): - dest_values[dest_block, i, j] = src_values[src_block, i + src_base_i, j + src_base_j] + dest_values[dest_block, i, j] = dest_values.dtype( + scale * src_values[src_block, i + src_base_i, j + src_base_j] + ) @wp.kernel @@ -403,6 +569,7 @@ def _bsr_assign_merge_row_col( @wp.kernel def _bsr_assign_merge_blocks( + scale: Any, row_factor: int, col_factor: int, src_row_count: int, @@ -437,14 +604,16 @@ def _bsr_assign_merge_blocks( for i in range(src_rows_per_block): for j in range(src_cols_per_block): - dest_values[dest_block, i + dest_base_i, j + dest_base_j] = src_values[src_block, i, j] + dest_values[dest_block, i + dest_base_i, j + dest_base_j] = dest_values.dtype( + scale * src_values[src_block, i, j] + ) def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array: if A.block_shape == (1, 1): return A.values.reshape((A.values.shape[0], 1, 1)) - return wp.array( + view = wp.array( data=None, ptr=A.values.ptr, capacity=A.values.capacity, @@ -452,11 +621,13 @@ def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array: dtype=A.scalar_type, shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]), 
) + view._ref = A.values + return view def bsr_assign( dest: BsrMatrix[BlockType[Rows, Cols, Scalar]], - src: BsrMatrix[BlockType[Rows, Cols, Any]], + src: BsrMatrixOrExpression[BlockType[Any, Any, Any]], structure_only: bool = False, ): """Copies the content of the `src` BSR matrix to `dest`. @@ -466,6 +637,8 @@ def bsr_assign( casting if the two matrices use distinct scalar types. """ + src, src_scale = _extract_matrix_and_scale(src) + if dest.values.device != src.values.device: raise ValueError("Source and destination matrices must reside on the same device") @@ -484,6 +657,7 @@ def bsr_assign( if not structure_only: warp.utils.array_cast(out_array=dest.values, in_array=src.values, count=nnz_alloc) + bsr_scale(dest, src_scale) elif src.block_shape[0] >= dest.block_shape[0] and src.block_shape[1] >= dest.block_shape[1]: # Split blocks @@ -517,6 +691,7 @@ def bsr_assign( device=dest.device, inputs=[ wp.bool(structure_only), + src.scalar_type(src_scale), row_factor, col_factor, dest.nrow, @@ -597,6 +772,7 @@ def bsr_assign( dim=src.nnz, device=dest.device, inputs=[ + src.scalar_type(src_scale), row_factor, col_factor, src.nrow, @@ -614,7 +790,7 @@ def bsr_assign( def bsr_copy( - A: BsrMatrix, + A: BsrMatrixOrExpression, scalar_type: Optional[Scalar] = None, block_shape: Optional[Tuple[int, int]] = None, structure_only: bool = False, @@ -643,7 +819,7 @@ def bsr_copy( rows_of_blocks=A.nrow, cols_of_blocks=A.ncol, block_type=block_type, - device=A.values.device, + device=A.device, ) bsr_assign(dest=copy, src=A) return copy @@ -651,10 +827,12 @@ def bsr_copy( def bsr_set_transpose( dest: BsrMatrix[BlockType[Cols, Rows, Scalar]], - src: BsrMatrix[BlockType[Rows, Cols, Scalar]], + src: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]], ): """Assigns the transposed matrix `src` to matrix `dest`""" + src, src_scale = _extract_matrix_and_scale(src) + if dest.values.device != src.values.device: raise ValueError("All arguments must reside on the same device") @@ -710,9 
+888,10 @@ def bsr_set_transpose( ) dest.copy_nnz_async() + bsr_scale(dest, src_scale) -def bsr_transposed(A: BsrMatrix): +def bsr_transposed(A: BsrMatrixOrExpression): """Returns a copy of the transposed matrix `A`""" if A.block_shape == (1, 1): @@ -724,7 +903,7 @@ def bsr_transposed(A: BsrMatrix): rows_of_blocks=A.ncol, cols_of_blocks=A.nrow, block_type=block_type, - device=A.values.device, + device=A.device, ) bsr_set_transpose(dest=transposed, src=A) return transposed @@ -732,6 +911,7 @@ def bsr_transposed(A: BsrMatrix): @wp.kernel def _bsr_get_diag_kernel( + scale: Any, A_offsets: wp.array(dtype=int), A_columns: wp.array(dtype=int), A_values: wp.array(dtype=Any), @@ -744,10 +924,10 @@ def _bsr_get_diag_kernel( diag = wp.lower_bound(A_columns, beg, end, row) if diag < end: if A_columns[diag] == row: - out[row] = A_values[diag] + out[row] = scale * A_values[diag] -def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]": +def bsr_get_diag(A: BsrMatrixOrExpression[BlockType], out: "Optional[Array[BlockType]]" = None) -> "Array[BlockType]": """Returns the array of blocks that constitute the diagonal of a sparse matrix. 
Args: @@ -755,6 +935,8 @@ def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = N out: if provided, the array into which to store the diagonal blocks """ + A, scale = _extract_matrix_and_scale(A) + dim = min(A.nrow, A.ncol) if out is None: @@ -771,7 +953,7 @@ def bsr_get_diag(A: BsrMatrix[_BlockType], out: "Optional[Array[BlockType]]" = N kernel=_bsr_get_diag_kernel, dim=dim, device=A.values.device, - inputs=[A.offsets, A.columns, A.values, out], + inputs=[A.scalar_type(scale), A.offsets, A.columns, A.values, out], ) return out @@ -970,11 +1152,14 @@ def _bsr_scale_kernel( values[wp.tid()] = alpha * values[wp.tid()] -def bsr_scale(x: BsrMatrix, alpha: Scalar) -> BsrMatrix: +def bsr_scale(x: BsrMatrixOrExpression, alpha: Scalar) -> BsrMatrix: """ Performs the operation ``x := alpha * x`` on BSR matrix `x` and returns `x` """ + x, scale = _extract_matrix_and_scale(x) + alpha *= scale + if alpha != 1.0 and x.nnz > 0: if alpha == 0.0: bsr_set_zero(x) @@ -1056,7 +1241,7 @@ def _allocate(self, device, y: BsrMatrix, sum_nnz: int): def bsr_axpy( - x: BsrMatrix[BlockType[Rows, Cols, Scalar]], + x: BsrMatrixOrExpression, y: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None, alpha: Scalar = 1.0, beta: Scalar = 1.0, @@ -1075,6 +1260,9 @@ def bsr_axpy( work_arrays: In most cases this function will require the use of temporary storage; this storage can be reused across calls by passing an instance of :class:`bsr_axpy_work_arrays` in `work_arrays`. 
""" + x, x_scale = _extract_matrix_and_scale(x) + alpha *= x_scale + if y is None: # If not output matrix is provided, allocate it for convenience y = bsr_zeros(x.nrow, x.ncol, block_type=x.values.dtype, device=x.values.device) @@ -1338,8 +1526,8 @@ def _allocate_stage_2(self, mm_nnz: int): def bsr_mm( - x: BsrMatrix[BlockType[Rows, Any, Scalar]], - y: BsrMatrix[BlockType[Any, Cols, Scalar]], + x: BsrMatrixOrExpression[BlockType[Rows, Any, Scalar]], + y: BsrMatrixOrExpression[BlockType[Any, Cols, Scalar]], z: Optional[BsrMatrix[BlockType[Rows, Cols, Scalar]]] = None, alpha: Scalar = 1.0, beta: Scalar = 0.0, @@ -1364,6 +1552,11 @@ def bsr_mm( This is neccessary for `bsr_mm` to be captured in a CUDA graph. """ + x, x_scale = _extract_matrix_and_scale(x) + alpha *= x_scale + y, y_scale = _extract_matrix_and_scale(y) + alpha *= y_scale + if z is None: # If not output matrix is provided, allocate it for convenience z_block_shape = (x.block_shape[0], y.block_shape[1]) @@ -1602,12 +1795,57 @@ def _bsr_mv_kernel( y[row] = v +@wp.kernel +def _bsr_mv_transpose_kernel( + alpha: Any, + A_offsets: wp.array(dtype=int), + A_columns: wp.array(dtype=int), + A_values: wp.array(dtype=Any), + x: wp.array(dtype=Any), + y: wp.array(dtype=Any), +): + row = wp.tid() + beg = A_offsets[row] + end = A_offsets[row + 1] + xr = alpha * x[row] + for block in range(beg, end): + v = wp.transpose(A_values[block]) * xr + wp.atomic_add(y, A_columns[block], v) + + +def _bsr_mv_as_vec_array(array: wp.array) -> wp.array: + if array.ndim == 1: + return array + + if array.ndim > 2: + raise ValueError(f"Incompatible array number of dimensions {array.ndim}") + + if not array.is_contiguous: + raise ValueError("2d array must be contiguous") + + def vec_view(array): + return wp.array( + data=None, + ptr=array.ptr, + capacity=array.capacity, + device=array.device, + dtype=wp.vec(length=array.shape[1], dtype=array.dtype), + shape=array.shape[0], + grad=None if array.grad is None else vec_view(array.grad), + ) + 
+ view = vec_view(array) + view._ref = array + return view + + def bsr_mv( - A: BsrMatrix[BlockType[Rows, Cols, Scalar]], + A: BsrMatrixOrExpression[BlockType[Rows, Cols, Scalar]], x: "Array[Vector[Cols, Scalar] | Scalar]", y: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None, alpha: Scalar = 1.0, beta: Scalar = 0.0, + transpose: bool = False, work_buffer: Optional["Array[Vector[Rows, Scalar] | Scalar]"] = None, ) -> "Array[Vector[Rows, Scalar] | Scalar]": """ @@ -1621,16 +1859,26 @@ def bsr_mv( y: Mutable left-hand-side. If `y` is not provided, it will be allocated and treated as zero. alpha: Uniform scaling factor for `x`. If zero, `x` will not be read and may be left uninitialized. beta: Uniform scaling factor for `y`. If zero, `y` will not be read and may be left uninitialized. + transpose: If ``True``, use the tranpose of the matrix `A`. In this case the result is **non-deterministic**. work_buffer: Temporary storage is required if and only if `x` and `y` are the same vector. If provided the `work_buffer` array will be used for this purpose, otherwise a temporary allocation will be performed. 
""" + A, A_scale = _extract_matrix_and_scale(A) + alpha *= A_scale + + if transpose: + block_shape = A.block_shape[1], A.block_shape[0] + nrow, ncol = A.ncol, A.nrow + else: + block_shape = A.block_shape + nrow, ncol = A.nrow, A.ncol + if y is None: # If no output array is provided, allocate one for convenience - y_vec_len = A.block_shape[0] + y_vec_len = block_shape[0] y_dtype = A.scalar_type if y_vec_len == 1 else wp.vec(length=y_vec_len, dtype=A.scalar_type) - y = wp.empty(shape=(A.nrow,), device=A.values.device, dtype=y_dtype) - y.zero_() + y = wp.empty(shape=(nrow,), device=A.values.device, dtype=y_dtype) beta = 0.0 if not isinstance(alpha, A.scalar_type): @@ -1641,11 +1889,15 @@ def bsr_mv( if A.values.device != x.device or A.values.device != y.device: raise ValueError("A, x and y must reside on the same device") - if x.shape[0] != A.ncol: + if x.shape[0] != ncol: raise ValueError("Number of columns of A must match number of rows of x") - if y.shape[0] != A.nrow: + if y.shape[0] != nrow: raise ValueError("Number of rows of A must match number of rows of y") + # View 2d arrays as arrays of vecs + x = _bsr_mv_as_vec_array(x) + y = _bsr_mv_as_vec_array(y) + if x == y: # Aliasing case, need temporary storage if work_buffer is None: @@ -1661,25 +1913,39 @@ def bsr_mv( # Promote scalar vectors to length-1 vecs and conversely if warp.types.type_is_matrix(A.values.dtype): - if A.block_shape[0] == 1: - if y.dtype == A.scalar_type: - y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type)) - if A.block_shape[1] == 1: - if x.dtype == A.scalar_type: - x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type)) + if block_shape[0] == 1 and y.dtype == A.scalar_type: + y = y.view(dtype=wp.vec(length=1, dtype=A.scalar_type)) + if block_shape[1] == 1 and x.dtype == A.scalar_type: + x = x.view(dtype=wp.vec(length=1, dtype=A.scalar_type)) else: - if A.block_shape[0] == 1: - if y.dtype != A.scalar_type: - y = y.view(dtype=A.scalar_type) - if A.block_shape[1] == 1: - if x.dtype != 
A.scalar_type: - x = x.view(dtype=A.scalar_type) - - wp.launch( - kernel=_bsr_mv_kernel, - device=A.values.device, - dim=A.nrow, - inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y], - ) + if block_shape[0] == 1 and y.dtype != A.scalar_type: + y = y.view(dtype=A.scalar_type) + if block_shape[1] == 1 and x.dtype != A.scalar_type: + x = x.view(dtype=A.scalar_type) + + if transpose: + if beta.value == 0.0: + y.zero_() + elif beta.value != 1.0: + wp.launch( + kernel=_bsr_scale_kernel, + device=y.device, + dim=y.shape[0], + inputs=[beta, y], + ) + if alpha.value != 0.0: + wp.launch( + kernel=_bsr_mv_transpose_kernel, + device=A.values.device, + dim=ncol, + inputs=[alpha, A.offsets, A.columns, A.values, x, y], + ) + else: + wp.launch( + kernel=_bsr_mv_kernel, + device=A.values.device, + dim=nrow, + inputs=[alpha, A.offsets, A.columns, A.values, x, beta, y], + ) return y diff --git a/warp/tests/test_sparse.py b/warp/tests/test_sparse.py index d63f32c2d..06f3a8c59 100644 --- a/warp/tests/test_sparse.py +++ b/warp/tests/test_sparse.py @@ -158,6 +158,9 @@ def test_bsr_get_set_diag(test, device): diag = bsr_get_diag(diag_bsr) assert_np_equal(diag_scalar_np, diag.numpy(), tol=0.000001) + diag = bsr_get_diag(2.0 * diag_bsr) + assert_np_equal(2.0 * diag_scalar_np, diag.numpy(), tol=0.000001) + # Uniform block diagonal with test.assertRaisesRegex(ValueError, "BsrMatrix block type must be either warp matrix or scalar"): @@ -249,14 +252,11 @@ def test_bsr_transpose(test, device): bsr = bsr_zeros(nrow, ncol, wp.types.matrix(shape=block_shape, dtype=scalar_type), device=device) bsr_set_from_triplets(bsr, rows, cols, vals) - ref = np.transpose(_bsr_to_dense(bsr)) + ref = 2.0 * np.transpose(_bsr_to_dense(bsr)) - bsr_transposed = bsr_zeros( - ncol, nrow, wp.types.matrix(shape=block_shape[::-1], dtype=scalar_type), device=device - ) - bsr_set_transpose(dest=bsr_transposed, src=bsr) + bsr_transposed = (2.0 * bsr).transpose() - res = _bsr_to_dense(bsr_transposed) + res = 
_bsr_to_dense(bsr_transposed.eval()) assert_np_equal(res, ref, 0.0001) if block_shape[0] != block_shape[-1]: @@ -297,17 +297,14 @@ def test_bsr_axpy(test, device): work_arrays = bsr_axpy_work_arrays() for alpha, beta in zip(alphas, betas): ref = alpha * _bsr_to_dense(x) + beta * _bsr_to_dense(y) - if beta == 0.0: - y = bsr_axpy(x, alpha=alpha, beta=beta, work_arrays=work_arrays) - else: - bsr_axpy(x, y, alpha, beta, work_arrays=work_arrays) + bsr_axpy(x, y, alpha, beta, work_arrays=work_arrays) res = _bsr_to_dense(y) assert_np_equal(res, ref, 0.0001) # test aliasing ref = 3.0 * _bsr_to_dense(y) - bsr_axpy(y, y, alpha=1.0, beta=2.0) + y += y * 2.0 res = _bsr_to_dense(y) assert_np_equal(res, ref, 0.0001) @@ -337,7 +334,7 @@ def test_bsr_mm(test, device): nnz = 6 - alphas = [-1.0, 0.0, 1.0] + alphas = [-1.0, 0.0, 2.0] betas = [2.0, -1.0, 0.0] x_rows = wp.array(rng.integers(0, high=x_nrow, size=nnz, dtype=int), dtype=int, device=device) @@ -378,6 +375,10 @@ def test_bsr_mm(test, device): bsr_mm(x, y, z, alpha, beta, work_arrays=work_arrays, reuse_topology=True) assert_np_equal(res, ref, 0.0001) + # using overloaded operators + x = (alpha * x) @ y + assert_np_equal(res, ref, 0.0001) + # test aliasing of matrix arguments # x = alpha * z * x + beta * x alpha, beta = alphas[0], betas[0] @@ -446,13 +447,19 @@ def test_bsr_mv(test, device): for alpha, beta in zip(alphas, betas): ref = alpha * _bsr_to_dense(A) @ x.numpy().flatten() + beta * y.numpy().flatten() if beta == 0.0: - y = bsr_mv(A, x, alpha=alpha, beta=beta, work_buffer=work_buffer) + y = A @ x else: bsr_mv(A, x, y, alpha, beta, work_buffer=work_buffer) res = y.numpy().flatten() assert_np_equal(res, ref, 0.0001) + # test tranposed product + ref = alpha * y.numpy().flatten() @ _bsr_to_dense(A) + x = y @ (A * alpha) + res = x.numpy().flatten() + assert_np_equal(res, ref, 0.0001) + # test aliasing alpha, beta = alphas[0], betas[0] AAt = bsr_mm(A, bsr_transposed(A)) diff --git a/warp/types.py b/warp/types.py index 
0eadfe005..46c330864 100644 --- a/warp/types.py +++ b/warp/types.py @@ -2164,6 +2164,9 @@ def __matmul__(self, other): """ Enables A @ B syntax for matrix multiplication """ + if not is_array(other): + return NotImplemented + if self.ndim != 2 or other.ndim != 2: raise RuntimeError( "A has dim = {}, B has dim = {}. If multiplying with @, A and B must have dim = 2.".format(