Merge branch 'lwawrzyniak/fix-gradient-slice' into 'main'
Fix slicing of arrays with gradients in kernels

See merge request omniverse/warp!555
nvlukasz committed Jun 6, 2024
2 parents 7858b28 + 9bd3f3a commit a4eabc4
Showing 3 changed files with 160 additions and 9 deletions.
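
For context, the behavior this fix enables at the user level: indexing an array inside a kernel returns a view, and with this change the view also references the parent array's gradient buffer, so adjoints written through the slice accumulate into the original array's .grad. Below is a minimal sketch using Warp's tape API; the kernel, names, and values are illustrative and not part of this commit.

    import warp as wp

    wp.init()

    @wp.kernel
    def square_row(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float), row: int):
        tid = wp.tid()
        x_row = x[row]  # 1D view of one row; with this fix it also carries x.grad
        y_row = y[row]
        y_row[tid] = x_row[tid] * x_row[tid]

    x = wp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=float, requires_grad=True)
    y = wp.zeros_like(x)  # inherits requires_grad=True from x

    tape = wp.Tape()
    with tape:
        wp.launch(square_row, dim=x.shape[1], inputs=[x, y, 1])

    # seed dL/dy with ones and run the adjoint of the recorded launch
    tape.backward(grads={y: wp.ones_like(y)})

    print(x.grad.numpy())  # expected: 2*x on row 1, zeros elsewhere -> [[0. 0.] [6. 8.] [0. 0.]]

Previously the views created by x[row] and y[row] carried only the data pointer, so the adjoint pass had nowhere to accumulate; the array.h change below sets the view's grad pointer to the parent's gradient buffer at the same byte offset as the data.
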
CHANGELOG.md (1 change: 1 addition & 0 deletions)
@@ -24,6 +24,7 @@
 - Support headless rendering of `OpenGLRenderer` via `pyglet.options["headless"] = True`
 - `RegisteredGLBuffer` can fall back to CPU-bound copying if CUDA/OpenGL interop is not available
 - Fix to forward `wp.copy()` params to gradient and adjoint copy function calls.
+- Fix slicing of arrays with gradients in kernels
 
 ## [1.1.1] - 2024-05-24
 
warp/native/array.h (15 changes: 12 additions & 3 deletions)
@@ -489,7 +489,10 @@ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
     assert(i >= 0 && i < src.shape[0]);
 
     array_t<T> a;
-    a.data = data_at_byte_offset(src, byte_offset(src, i));
+    size_t offset = byte_offset(src, i);
+    a.data = data_at_byte_offset(src, offset);
+    if (src.grad)
+        a.grad = grad_at_byte_offset(src, offset);
     a.shape[0] = src.shape[1];
     a.shape[1] = src.shape[2];
     a.shape[2] = src.shape[3];
@@ -509,7 +512,10 @@ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
     assert(j >= 0 && j < src.shape[1]);
 
     array_t<T> a;
-    a.data = data_at_byte_offset(src, byte_offset(src, i, j));
+    size_t offset = byte_offset(src, i, j);
+    a.data = data_at_byte_offset(src, offset);
+    if (src.grad)
+        a.grad = grad_at_byte_offset(src, offset);
     a.shape[0] = src.shape[2];
     a.shape[1] = src.shape[3];
     a.strides[0] = src.strides[2];
@@ -528,7 +534,10 @@ CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
     assert(k >= 0 && k < src.shape[2]);
 
     array_t<T> a;
-    a.data = data_at_byte_offset(src, byte_offset(src, i, j, k));
+    size_t offset = byte_offset(src, i, j, k);
+    a.data = data_at_byte_offset(src, offset);
+    if (src.grad)
+        a.grad = grad_at_byte_offset(src, offset);
     a.shape[0] = src.shape[3];
     a.strides[0] = src.strides[3];
     a.ndim = src.ndim-3;
warp/tests/test_grad.py (153 changes: 147 additions & 6 deletions)
@@ -663,16 +663,40 @@ def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
     y[tid] = x[tid] ** 2.0
 
 
+@wp.kernel
+def square_slice_2d_kernel(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float), row_idx: int):
+    tid = wp.tid()
+    x_slice = x[row_idx]
+    y_slice = y[row_idx]
+    y_slice[tid] = x_slice[tid] ** 2.0
+
+
+@wp.kernel
+def square_slice_3d_1d_kernel(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float), slice_idx: int):
+    i, j = wp.tid()
+    x_slice = x[slice_idx]
+    y_slice = y[slice_idx]
+    y_slice[i, j] = x_slice[i, j] ** 2.0
+
+
+@wp.kernel
+def square_slice_3d_2d_kernel(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float), slice_i: int, slice_j: int):
+    tid = wp.tid()
+    x_slice = x[slice_i, slice_j]
+    y_slice = y[slice_i, slice_j]
+    y_slice[tid] = x_slice[tid] ** 2.0
+
+
 def test_gradient_internal(test, device):
     with wp.ScopedDevice(device):
         a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
         b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=True)
 
-        wp.launch(square_kernel, a.size, inputs=[a, b])
+        wp.launch(square_kernel, dim=a.size, inputs=[a, b])
 
         # use internal gradients (.grad), adj_inputs are None
         b.grad = wp.array([1.0, 1.0, 1.0], dtype=float)
-        wp.launch(square_kernel, a.shape[0], inputs=[a, b], adjoint=True, adj_inputs=[None, None])
+        wp.launch(square_kernel, dim=a.size, inputs=[a, b], adjoint=True, adj_inputs=[None, None])
 
         assert_np_equal(a.grad.numpy(), np.array([2.0, 4.0, 6.0]))
 
@@ -682,12 +706,12 @@ def test_gradient_external(test, device):
         a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=False)
         b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=False)
 
-        wp.launch(square_kernel, a.size, inputs=[a, b])
+        wp.launch(square_kernel, dim=a.size, inputs=[a, b])
 
         # use external gradients passed in adj_inputs
         a_grad = wp.array([0.0, 0.0, 0.0], dtype=float)
         b_grad = wp.array([1.0, 1.0, 1.0], dtype=float)
-        wp.launch(square_kernel, a.shape[0], inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])
+        wp.launch(square_kernel, dim=a.size, inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])
 
         assert_np_equal(a_grad.numpy(), np.array([2.0, 4.0, 6.0]))
 
@@ -697,18 +721,132 @@ def test_gradient_precedence(test, device):
         a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
         b = wp.array([0.0, 0.0, 0.0], dtype=float, requires_grad=True)
 
-        wp.launch(square_kernel, a.size, inputs=[a, b])
+        wp.launch(square_kernel, dim=a.size, inputs=[a, b])
 
         # if both internal and external gradients are present, the external one takes precedence,
         # because it's explicitly passed by the user in adj_inputs
        a_grad = wp.array([0.0, 0.0, 0.0], dtype=float)
         b_grad = wp.array([1.0, 1.0, 1.0], dtype=float)
-        wp.launch(square_kernel, a.shape[0], inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])
+        wp.launch(square_kernel, dim=a.size, inputs=[a, b], adjoint=True, adj_inputs=[a_grad, b_grad])
 
         assert_np_equal(a_grad.numpy(), np.array([2.0, 4.0, 6.0]))  # used
         assert_np_equal(a.grad.numpy(), np.array([0.0, 0.0, 0.0]))  # unused
 
+
+def test_gradient_slice_2d(test, device):
+    with wp.ScopedDevice(device):
+        a = wp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=float, requires_grad=True)
+        b = wp.zeros_like(a, requires_grad=False)
+        b.grad = wp.ones_like(a, requires_grad=False)
+
+        wp.launch(square_slice_2d_kernel, dim=a.shape[1], inputs=[a, b, 1])
+
+        # use internal gradients (.grad), adj_inputs are None
+        wp.launch(square_slice_2d_kernel, dim=a.shape[1], inputs=[a, b, 1], adjoint=True, adj_inputs=[None, None])
+
+        assert_np_equal(a.grad.numpy(), np.array([[0.0, 0.0], [6.0, 8.0], [0.0, 0.0]]))
+
+
+def test_gradient_slice_3d_1d(test, device):
+    with wp.ScopedDevice(device):
+        data = [
+            [
+                [1, 2, 3],
+                [4, 5, 6],
+                [7, 8, 9],
+            ],
+            [
+                [11, 12, 13],
+                [14, 15, 16],
+                [17, 18, 19],
+            ],
+            [
+                [21, 22, 23],
+                [24, 25, 26],
+                [27, 28, 29],
+            ],
+        ]
+        a = wp.array(data, dtype=float, requires_grad=True)
+        b = wp.zeros_like(a, requires_grad=False)
+        b.grad = wp.ones_like(a, requires_grad=False)
+
+        wp.launch(square_slice_3d_1d_kernel, dim=a.shape[1:], inputs=[a, b, 1])
+
+        # use internal gradients (.grad), adj_inputs are None
+        wp.launch(
+            square_slice_3d_1d_kernel, dim=a.shape[1:], inputs=[a, b, 1], adjoint=True, adj_inputs=[None, None, 1]
+        )
+
+        expected_grad = [
+            [
+                [0, 0, 0],
+                [0, 0, 0],
+                [0, 0, 0],
+            ],
+            [
+                [11 * 2, 12 * 2, 13 * 2],
+                [14 * 2, 15 * 2, 16 * 2],
+                [17 * 2, 18 * 2, 19 * 2],
+            ],
+            [
+                [0, 0, 0],
+                [0, 0, 0],
+                [0, 0, 0],
+            ],
+        ]
+        assert_np_equal(a.grad.numpy(), np.array(expected_grad))
+
+
+def test_gradient_slice_3d_2d(test, device):
+    with wp.ScopedDevice(device):
+        data = [
+            [
+                [1, 2, 3],
+                [4, 5, 6],
+                [7, 8, 9],
+            ],
+            [
+                [11, 12, 13],
+                [14, 15, 16],
+                [17, 18, 19],
+            ],
+            [
+                [21, 22, 23],
+                [24, 25, 26],
+                [27, 28, 29],
+            ],
+        ]
+        a = wp.array(data, dtype=float, requires_grad=True)
+        b = wp.zeros_like(a, requires_grad=False)
+        b.grad = wp.ones_like(a, requires_grad=False)
+
+        wp.launch(square_slice_3d_2d_kernel, dim=a.shape[2], inputs=[a, b, 1, 1])
+
+        # use internal gradients (.grad), adj_inputs are None
+        wp.launch(
+            square_slice_3d_2d_kernel, dim=a.shape[2], inputs=[a, b, 1, 1], adjoint=True, adj_inputs=[None, None, 1, 1]
+        )
+
+        expected_grad = [
+            [
+                [0, 0, 0],
+                [0, 0, 0],
+                [0, 0, 0],
+            ],
+            [
+                [0, 0, 0],
+                [14 * 2, 15 * 2, 16 * 2],
+                [0, 0, 0],
+            ],
+            [
+                [0, 0, 0],
+                [0, 0, 0],
+                [0, 0, 0],
+            ],
+        ]
+        assert_np_equal(a.grad.numpy(), np.array(expected_grad))
+

devices = get_test_devices()


@@ -737,6 +875,9 @@ class TestGrad(unittest.TestCase):
 add_function_test(TestGrad, "test_gradient_internal", test_gradient_internal, devices=devices)
 add_function_test(TestGrad, "test_gradient_external", test_gradient_external, devices=devices)
 add_function_test(TestGrad, "test_gradient_precedence", test_gradient_precedence, devices=devices)
+add_function_test(TestGrad, "test_gradient_slice_2d", test_gradient_slice_2d, devices=devices)
+add_function_test(TestGrad, "test_gradient_slice_3d_1d", test_gradient_slice_3d_1d, devices=devices)
+add_function_test(TestGrad, "test_gradient_slice_3d_2d", test_gradient_slice_3d_2d, devices=devices)
 
 
 if __name__ == "__main__":
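
A note on the launch pattern used by the new tests: wp.launch(..., adjoint=True, adj_inputs=...) runs the kernel's adjoint directly, without recording a tape. A None entry in adj_inputs means the corresponding array's internal .grad buffer is used, while non-array arguments reappear in adj_inputs unchanged (the 3D tests pass the slice index again). A minimal sketch of the same pattern with the existing square_kernel; the values are illustrative.

    import warp as wp

    wp.init()

    @wp.kernel
    def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = x[tid] ** 2.0

    a = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
    b = wp.zeros_like(a)

    # forward pass
    wp.launch(square_kernel, dim=a.size, inputs=[a, b])

    # seed the output gradient, then launch the adjoint kernel directly;
    # None entries mean "use this array's internal .grad"
    b.grad = wp.array([1.0, 1.0, 1.0], dtype=float)
    wp.launch(square_kernel, dim=a.size, inputs=[a, b], adjoint=True, adj_inputs=[None, None])

    print(a.grad.numpy())  # expected: [2. 4. 6.]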
