Skip to content

Commit

Permalink
Fix the aliasing check in bsr_mv and run native functions under the correct device scope
Browse files: browse the repository at this point in the history
  • Loading branch information
gdaviet committed Jul 1, 2024
1 parent 14a82a5 commit ee53026
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 81 deletions.
162 changes: 82 additions & 80 deletions warp/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,21 +355,22 @@ def bsr_set_from_triplets(

nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()

native_func(
dest.block_shape[0],
dest.block_shape[1],
dest.nrow,
nnz,
ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(values.ptr, ctypes.c_void_p),
prune_numerical_zeros,
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.values.ptr, ctypes.c_void_p),
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
nnz_event,
)
with wp.ScopedDevice(device):
native_func(
dest.block_shape[0],
dest.block_shape[1],
dest.nrow,
nnz,
ctypes.cast(rows.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(columns.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(values.ptr, ctypes.c_void_p),
prune_numerical_zeros,
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.values.ptr, ctypes.c_void_p),
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
nnz_event,
)


class _BsrExpression(Generic[_BlockType]):
Expand Down Expand Up @@ -613,16 +614,14 @@ def _bsr_values_as_3d_array(A: BsrMatrix) -> wp.array:
if A.block_shape == (1, 1):
return A.values.reshape((A.values.shape[0], 1, 1))

view = wp.array(
return wp.array(
data=None,
ptr=A.values.ptr,
capacity=A.values.capacity,
device=A.device,
dtype=A.scalar_type,
shape=(A.values.shape[0], A.block_shape[0], A.block_shape[1]),
)
view._ref = A.values
return view


def bsr_assign(
Expand Down Expand Up @@ -748,21 +747,22 @@ def bsr_assign(
native_func = runtime.core.bsr_matrix_from_triplets_float_device

nnz_buf, nnz_event = dest._nnz_transfer_buf_and_event()
native_func(
dest.block_shape[0],
dest.block_shape[1],
dest.nrow,
dest.nnz,
ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
False,
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
nnz_event,
)
with wp.ScopedDevice(dest.device):
native_func(
dest.block_shape[0],
dest.block_shape[1],
dest.nrow,
dest.nnz,
ctypes.cast(dest_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
False,
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
nnz_event,
)

# merge block values
if not structure_only:
Expand Down Expand Up @@ -849,9 +849,7 @@ def bsr_set_transpose(
dest.ncol = src.nrow

if nnz == 0:
dest.nnz = 0
dest.offsets.zero_()
dest.launch_nnz_transfer(0)
bsr_set_zero(dest)
return

# Increase dest array sizes if needed
Expand All @@ -873,19 +871,20 @@ def bsr_set_transpose(
if not native_func:
raise NotImplementedError(f"bsr_set_transpose not implemented for scalar type {dest.scalar_type}")

native_func(
src.block_shape[0],
src.block_shape[1],
src.nrow,
src.ncol,
nnz,
ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(src.values.ptr, ctypes.c_void_p),
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.values.ptr, ctypes.c_void_p),
)
with wp.ScopedDevice(dest.device):
native_func(
src.block_shape[0],
src.block_shape[1],
src.nrow,
src.ncol,
nnz,
ctypes.cast(src.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(src.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(src.values.ptr, ctypes.c_void_p),
ctypes.cast(dest.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(dest.values.ptr, ctypes.c_void_p),
)

dest.copy_nnz_async()
bsr_scale(dest, src_scale)
Expand Down Expand Up @@ -1339,21 +1338,22 @@ def bsr_axpy(
old_y_nnz = y_nnz
nnz_buf, nnz_event = y._nnz_transfer_buf_and_event()

native_func(
y.block_shape[0],
y.block_shape[1],
y.nrow,
sum_nnz,
ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
False,
ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
nnz_event,
)
with wp.ScopedDevice(y.device):
native_func(
y.block_shape[0],
y.block_shape[1],
y.nrow,
sum_nnz,
ctypes.cast(work_arrays._sum_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(work_arrays._sum_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
False,
ctypes.cast(y.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(y.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
nnz_event,
)

_bsr_ensure_fits(y, nnz=sum_nnz)

Expand Down Expand Up @@ -1688,21 +1688,23 @@ def bsr_mm(
native_func = runtime.core.bsr_matrix_from_triplets_float_device

nnz_buf, nnz_event = z._nnz_transfer_buf_and_event()
native_func(
z.block_shape[0],
z.block_shape[1],
z.nrow,
mm_nnz,
ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
False,
ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
nnz_event,
)

with wp.ScopedDevice(z.device):
native_func(
z.block_shape[0],
z.block_shape[1],
z.nrow,
mm_nnz,
ctypes.cast(work_arrays._mm_rows.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(work_arrays._mm_cols.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
False,
ctypes.cast(z.offsets.ptr, ctypes.POINTER(ctypes.c_int32)),
ctypes.cast(z.columns.ptr, ctypes.POINTER(ctypes.c_int32)),
0,
ctypes.cast(nnz_buf.ptr, ctypes.POINTER(ctypes.c_int32)),
nnz_event,
)

# Resize z to fit mm result if necessary

Expand Down Expand Up @@ -1898,7 +1900,7 @@ def bsr_mv(
x = _bsr_mv_as_vec_array(x)
y = _bsr_mv_as_vec_array(y)

if x == y:
if x.ptr == y.ptr:
# Aliasing case, need temporary storage
if work_buffer is None:
work_buffer = wp.empty_like(y)
Expand Down
4 changes: 3 additions & 1 deletion warp/tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,10 @@ def test_bsr_mv(test, device):
assert_np_equal(res, ref, 0.0001)

# test aliasing
alpha, beta = alphas[0], betas[0]
AAt = bsr_mm(A, bsr_transposed(A))
assert_np_equal(_bsr_to_dense(AAt), _bsr_to_dense(A) @ _bsr_to_dense(A).T, 0.0001)

alpha, beta = alphas[0], betas[0]
ref = alpha * _bsr_to_dense(AAt) @ y.numpy().flatten() + beta * y.numpy().flatten()
bsr_mv(AAt, y, y, alpha, beta)
res = y.numpy().flatten()
Expand Down

0 comments on commit ee53026

Please sign in to comment.