diff --git a/warp/sparse.py b/warp/sparse.py
index 5e74ac50c..93ab48c40 100644
--- a/warp/sparse.py
+++ b/warp/sparse.py
@@ -355,21 +355,22 @@ def bsr_set_from_triplets(
 
     nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
 
-    native_func(
-        dest.block_shape[0],
-        dest.block_shape[1],
-        dest.nrow,
-        nnz,
-        ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(values.ptr, ctypes.c_void_p),
-        prune_numerical_zeros,
-        ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(dest.values.ptr, ctypes.c_void_p),
-        ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
-        nnz_event,
-    )
+    with wp.ScopedDevice(device):
+        native_func(
+            dest.block_shape[0],
+            dest.block_shape[1],
+            dest.nrow,
+            nnz,
+            ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(values.ptr, ctypes.c_void_p),
+            prune_numerical_zeros,
+            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(dest.values.ptr, ctypes.c_void_p),
+            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
+            nnz_event,
+        )
 
 
 class _BsrExpression(Generic[_BlockType]):
@@ -613,7 +614,7 @@ def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array:
     if A.block_shape == (1, 1):
         return A.values.reshape((A.values.shape[0], 1, 1))
 
-    view = wp.array(
+    return wp.array(
         data=None,
         ptr=A.values.ptr,
         capacity=A.values.capacity,
@@ -621,8 +622,6 @@
         dtype=A.scalar_type,
         shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]),
     )
-    view._ref = A.values
-    return view
 
 
 def bsr_assign(
@@ -748,21 +747,22 @@
         native_func = runtime.core.bsr_matrix_from_triplets_float_device
 
     nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
-    native_func(
-        dest.block_shape[0],
-        dest.block_shape[1],
-        dest.nrow,
-        dest.nnz,
-        ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
-        0,
-        False,
-        ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-        0,
-        ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
-        nnz_event,
-    )
+    with wp.ScopedDevice(dest.device):
+        native_func(
+            dest.block_shape[0],
+            dest.block_shape[1],
+            dest.nrow,
+            dest.nnz,
+            ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
+            0,
+            False,
+            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+            0,
+            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
+            nnz_event,
+        )
 
     # merge block values
     if not structure_only:
@@ -849,9 +849,7 @@ def bsr_set_transpose(
     dest.ncol = src.nrow
 
     if nnz == 0:
-        dest.nnz = 0
-        dest.offsets.zero_()
-        dest.launch_nnz_transfer(0)
+        bsr_set_zero(dest)
         return
 
     # Increase dest array sizes if needed
@@ -873,19 +871,20 @@
     if not native_func:
        raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")
 
-    native_func(
-        src.block_shape[0],
-        src.block_shape[1],
-        src.nrow,
-        src.ncol,
-        nnz,
-        ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(src.values.ptr, ctypes.c_void_p),
-        ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(dest.values.ptr, ctypes.c_void_p),
-    )
+    with wp.ScopedDevice(dest.device):
+        native_func(
+            src.block_shape[0],
+            src.block_shape[1],
+            src.nrow,
+            src.ncol,
+            nnz,
+            ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(src.values.ptr, ctypes.c_void_p),
+            ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(dest.values.ptr, ctypes.c_void_p),
+        )
     dest.copy_nnz_async()
 
     bsr_scale(dest, src_scale)
@@ -1339,21 +1338,22 @@ def bsr_axpy(
     old_y_nnz = y_nnz
     nnz_buf, nnz_event = y._nnz_transfer_buf_and_event()
 
-    native_func(
-        y.block_shape[0],
-        y.block_shape[1],
-        y.nrow,
-        sum_nnz,
-        ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
-        0,
-        False,
-        ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-        0,
-        ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
-        nnz_event,
-    )
+    with wp.ScopedDevice(y.device):
+        native_func(
+            y.block_shape[0],
+            y.block_shape[1],
+            y.nrow,
+            sum_nnz,
+            ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
+            0,
+            False,
+            ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+            0,
+            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
+            nnz_event,
+        )
 
     _bsr_ensure_fits(y, nnz=sum_nnz)
 
@@ -1688,21 +1688,23 @@ def bsr_mm(
         native_func = runtime.core.bsr_matrix_from_triplets_float_device
 
     nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
-    native_func(
-        z.block_shape[0],
-        z.block_shape[1],
-        z.nrow,
-        mm_nnz,
-        ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
-        0,
-        False,
-        ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
-        ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
-        0,
-        ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
-        nnz_event,
-    )
+
+    with wp.ScopedDevice(z.device):
+        native_func(
+            z.block_shape[0],
+            z.block_shape[1],
+            z.nrow,
+            mm_nnz,
+            ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
+            0,
+            False,
+            ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
+            ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
+            0,
+            ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
+            nnz_event,
+        )
 
 
     # Resize z to fit mm result if necessary
@@ -1898,7 +1900,7 @@ def bsr_mv(
         x = _bsr_mv_as_vec_array(x)
         y = _bsr_mv_as_vec_array(y)
 
-    if x == y:
+    if x.ptr == y.ptr:
         # Aliasing case, need temporary storage
         if work_buffer is None:
            work_buffer = wp.empty_like(y)
diff --git a/warp/tests/test_sparse.py b/warp/tests/test_sparse.py
index 06f3a8c59..513a9719e 100644
--- a/warp/tests/test_sparse.py
+++ b/warp/tests/test_sparse.py
@@ -461,8 +461,10 @@ def test_bsr_mv(test, device):
        assert_np_equal(res, ref, 0.0001)
 
     # test aliasing
-    alpha, beta = alphas[0], betas[0]
     AAt = bsr_mm(A, bsr_transposed(A))
+    assert_np_equal(_bsr_to_dense(AAt), _bsr_to_dense(A) @ _bsr_to_dense(A).T, 0.0001)
+
+    alpha, beta = alphas[0], betas[0]
     ref = alpha * _bsr_to_dense(AAt) @ y.numpy().flatten() + beta * y.numpy().flatten()
     bsr_mv(AAt, y, y, alpha, beta)
     res = y.numpy().flatten()
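
For context, a minimal usage sketch (not part of the patch) of the code path the new `wp.ScopedDevice` guards protect: assembling a BSR matrix from triplets on an explicitly chosen device. The device-selection logic and array contents below are illustrative assumptions only, not taken from the Warp test suite.

```python
import numpy as np
import warp as wp
from warp.sparse import bsr_set_from_triplets, bsr_zeros

wp.init()

# Pick a non-default device when one is available (illustrative choice only).
device = "cuda:1" if len(wp.get_cuda_devices()) > 1 else wp.get_preferred_device()

# COO-style triplets describing a 3x3 diagonal matrix with scalar (1x1) blocks.
rows = wp.array([0, 1, 2], dtype=wp.int32, device=device)
cols = wp.array([0, 1, 2], dtype=wp.int32, device=device)
vals = wp.array(np.ones(3, dtype=np.float32), dtype=wp.float32, device=device)

A = bsr_zeros(3, 3, wp.float32, device=device)

# With this patch, bsr_set_from_triplets wraps its native triplet-compression
# call in wp.ScopedDevice(device), so the native work targets A's device even
# if the caller's current default device is a different one.
bsr_set_from_triplets(A, rows, cols, vals)
```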