diff --git a/CHANGELOG.md b/CHANGELOG.md index 02cf8b0f6..1d2c0fce3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,15 @@ - Fixed differentiability of `wp.fem.PicQuadrature` w.r.t. positions and measures - Improve error messages for unsupported constructs - Update `wp.matmul()` CPU fallback to use dtype explicitly in `np.matmul()` call +- Add array overwrite detection if `wp.config.verify_autograd_array_access` is True. Array overwrites on the tape may corrupt gradient computation in the backward pass + - Adds `is_read` and `is_write` flags to kernel array args, which are set to `True` if an array arg is determined to be read from and/or written to during compilation + - If a kernel array arg is read from then written to within the same kernel during compilation, a warning is printed + - Adds the `is_read` flag to warp arrays, which is used to track whether an array has been read from in a kernel or recorded func at runtime + - If a warp array is passed to a kernel arg with attribute `is_read = True`, the warp array's `is_read` flag is set to `True` + - If a warp array with attribute `is_read = True` is subsequently passed to a kernel arg with attribute `is_write = True` (write after read overwrite condition), a warning is printed, indicating gradient corruption is possible in the backward pass + - Adds `wp.array.mark_write()` and `wp.array.mark_read()`, which are used to manually mark arrays that are written to or read from in functions recorded with `wp.Tape.record_func()` + - Adds `wp.Tape.reset_array_read_flags()` method, which resets all tape array `is_read` flags to `False`. + - Configures all view-like array methods to inherit `is_read` flag from parent arrays at creation. - Fix ShapeInstancer `__new__()` method (missing instance return and `*args` parameter) - Add support for PEP 563's `from __future__ import annotations`. - Allow passing external arrays/tensors to Warp kernels directly via `__cuda_array_interface__` and `__array_interface__` @@ -40,6 +49,7 @@ ## [1.2.2] - 2024-07-04 - Support for NumPy >= 2.0 +- Fix hashing of replay functions and snippets - Add additional documentation and examples demonstrating `wp.copy()`, `wp.clone()`, and `array.assign()` differentiability - Fix adding `__new__()` methods for all class `__del__()` methods to anticipate when a class instance is created but not instantiated before garbage collection. diff --git a/docs/modules/differentiability.rst b/docs/modules/differentiability.rst index 6bc20bfba..f5c9bc093 100644 --- a/docs/modules/differentiability.rst +++ b/docs/modules/differentiability.rst @@ -772,6 +772,113 @@ In the example above we can see that the array ``c`` does not have its ``require .. note:: Arrays can be labeled with custom names using the ``array_labels`` argument to the ``tape.visualize()`` method. +Array Overwrite Tracking +^^^^^^^^^^^^^^^^^^^^^^^^^ + +It is a common mistake to inadvertently overwrite an array that participates in the computation graph. 
+For example::
+
+    tape = wp.Tape()
+    with tape:
+
+        # step 1
+        wp.launch(compute_forces, dim=n, inputs=[pos0, vel0], outputs=[force])
+        wp.launch(simulate, dim=n, inputs=[pos0, vel0, force], outputs=[pos1, vel1])
+
+        # step 2 (error, we are overwriting previous forces)
+        wp.launch(compute_forces, dim=n, inputs=[pos1, vel1], outputs=[force])
+        wp.launch(simulate, dim=n, inputs=[pos1, vel1, force], outputs=[pos2, vel2])
+
+        # compute loss
+        wp.launch(loss_kernel, dim=n, inputs=[pos2, loss])
+
+    tape.backward(loss)
+
+Running the tape backwards will incorrectly compute the gradient of the loss with respect to ``pos0`` and ``vel0``, because ``force`` is overwritten in the second simulation step.
+The gradients propagated through the step-2 kernels back to ``pos1`` and ``vel1`` will be correct, because the stored value of ``force`` still matches what was computed in step 2,
+but the gradients propagated through the step-1 kernels back to ``pos0`` and ``vel0`` will be incorrect, because the ``force`` values used in that part of the backward pass were
+computed in step 2, not step 1. The solution is to allocate two force arrays, ``force0`` and ``force1``, so that we are not overwriting data that participates in the computation graph.
+
+This sort of problem boils down to a single pattern to avoid: writing to an array after reading from it. This typically happens across consecutive kernel launches (A), but it can also happen within a single kernel (B).
+
+A: Inter-Kernel Overwrite::
+
+    import numpy as np
+    import warp as wp
+
+    @wp.kernel
+    def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
+        tid = wp.tid()
+        y[tid] = x[tid] * x[tid]
+
+    @wp.kernel
+    def overwrite_kernel(z: wp.array(dtype=float), x: wp.array(dtype=float)):
+        tid = wp.tid()
+        x[tid] = z[tid]
+
+    @wp.kernel
+    def loss_kernel(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
+        tid = wp.tid()
+        wp.atomic_add(loss, 0, x[tid])
+
+    a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True)
+    b = wp.zeros_like(a)
+    c = wp.array(np.array([-1.0, -2.0, -3.0]), dtype=float, requires_grad=True)
+    loss = wp.zeros(1, dtype=float, requires_grad=True)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(square_kernel, a.shape, inputs=[a], outputs=[b])
+        wp.launch(overwrite_kernel, c.shape, inputs=[c], outputs=[a])
+        wp.launch(loss_kernel, a.shape, inputs=[b, loss])
+
+    tape.backward(loss)
+
+    print(a.grad)
+    # prints [-2. -4. -6.] instead of [2. 4. 6.]
+
+B: Intra-Kernel Overwrite::
+
+    import numpy as np
+    import warp as wp
+
+    @wp.kernel
+    def readwrite_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
+        tid = wp.tid()
+        b[tid] = a[tid] * a[tid]
+        a[tid] = 1.0
+
+    @wp.kernel
+    def loss_kernel(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
+        tid = wp.tid()
+        wp.atomic_add(loss, 0, x[tid])
+
+    a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True)
+    b = wp.zeros_like(a)
+    loss = wp.zeros(1, dtype=float, requires_grad=True)
+
+    tape = wp.Tape()
+    with tape:
+        wp.launch(readwrite_kernel, dim=a.shape, inputs=[a, b])
+        wp.launch(loss_kernel, a.shape, inputs=[b, loss])
+
+    tape.backward(loss)
+
+    print(a.grad)
+    # prints [2. 2. 2.] instead of [2. 4. 6.]
+
+If ``wp.config.verify_autograd_array_access = True`` is set, Warp will automatically detect and report array overwrites, covering the above two cases as well as other problematic configurations.
+It does so by flagging which kernel array arguments are read from and/or written to in each kernel function during compilation. At runtime, if an array is passed to a kernel argument marked with a read flag,
+it is marked as having been read from. Later, if the same array is passed to a kernel argument marked with a write flag, a warning is printed
+(recall the pattern we wish to avoid: *write* after *read*).
+
+.. note::
+    Setting ``wp.config.verify_autograd_array_access = True`` will disable kernel caching and force the current module to rebuild.
+
+.. note::
+    Though in-place operations such as ``x[tid] += 1.0`` are technically ``read -> write``, the Warp graph specifically accommodates adjoint accumulation in these cases, so we mark them as write operations.
+
+.. note::
+    This feature does not yet support arrays packed in Warp structs.
+
+If you make use of :py:meth:`Tape.record_func` in your graph (and so provide your own adjoint callback), be sure to also call :py:meth:`array.mark_write()` and :py:meth:`array.mark_read()`, which manually mark your arrays as having been written to or read from.
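+
+For example, a custom operation whose adjoint is provided through :py:meth:`Tape.record_func` might mark its arrays as follows. This is only an illustrative sketch: ``scale_forward``, ``scale_backward``, and the arrays are not part of Warp::
+
+    import numpy as np
+    import warp as wp
+
+    wp.config.verify_autograd_array_access = True
+
+    def scale_forward(x, y, s):
+        # forward: y = s * x, computed outside of a Warp kernel
+        y.assign(s * x.numpy())
+
+    def scale_backward(x, y, s):
+        # adjoint: dL/dx += s * dL/dy
+        x.grad.assign(x.grad.numpy() + s * y.grad.numpy())
+
+    x = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True)
+    y = wp.zeros_like(x)
+
+    tape = wp.Tape()
+    with tape:
+        scale_forward(x, y, 2.0)
+        tape.record_func(backward=lambda: scale_backward(x, y, 2.0), arrays=[x, y])
+
+        # mirror the callback's array accesses so overwrite tracking stays accurate
+        x.mark_read()
+        y.mark_write()
+
+This mirrors what built-ins such as ``wp.copy()`` do when verification is enabled: they call :py:meth:`array.mark_write` on their destination arrays and :py:meth:`array.mark_read` on their source arrays.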
+
 .. _limitations_and_workarounds:
 
 Limitations and Workarounds
diff --git a/warp/codegen.py b/warp/codegen.py
index 184c8d1ec..21a5ff625 100644
--- a/warp/codegen.py
+++ b/warp/codegen.py
@@ -582,6 +582,14 @@ def __init__(self, label, type, requires_grad=False, constant=None, prefix=True)
         self.constant = constant
         self.prefix = prefix
 
+        # records whether this Var has been read from in a kernel function (array only)
+        self.is_read = False
+        # records whether this Var has been written to in a kernel function (array only)
+        self.is_write = False
+
+        # used to associate a view array Var with its parent array Var
+        self.parent = None
+
     def __str__(self):
         return self.label
 
@@ -624,6 +632,42 @@ def emit(self, prefix: str = "var"):
     def emit_adj(self):
         return self.emit("adj")
 
+    def mark_read(self):
+        """Marks this Var as having been read from in a kernel (array only)."""
+        if not is_array(self.type):
+            return
+
+        self.is_read = True
+
+        # recursively update all parent states
+        parent = self.parent
+        while parent is not None:
+            parent.is_read = True
+            parent = parent.parent
+
+    def mark_write(self, **kwargs):
+        """Marks this Var as having been written to in a kernel (array only)."""
+        if not is_array(self.type):
+            return
+
+        # detect if we are writing to an array after reading from it within the same kernel
+        if self.is_read and warp.config.verify_autograd_array_access:
+            if "kernel_name" in kwargs and "filename" in kwargs and "lineno" in kwargs:
+                print(
+                    f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
+                )
+            else:
+                print(
+                    f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
+ ) + self.is_write = True + + # recursively update all parent states + parent = self.parent + while parent is not None: + parent.is_write = True + parent = parent.parent + class Block: # Represents a basic block of instructions, e.g.: list @@ -821,6 +865,11 @@ def __init__( # generate function ssa form and adjoint def build(adj, builder, default_builder_options=None): + # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build + for arg in adj.args: + arg.is_read = False + arg.is_write = False + if adj.skip_build: return @@ -1702,7 +1751,24 @@ def emit_Constant(adj, node): def emit_BinOp(adj, node): # evaluate binary operator arguments + + if warp.config.verify_autograd_array_access: + # array overwrite tracking: in-place operators are a special case + # x[tid] = x[tid] + 1 is a read followed by a write, but we only want to record the write + # so we save the current arg read flags and restore them after lhs eval + is_read_states = [] + for arg in adj.args: + is_read_states.append(arg.is_read) + + # evaluate lhs binary operator argument left = adj.eval(node.left) + + if warp.config.verify_autograd_array_access: + # restore arg read flags + for i, arg in enumerate(adj.args): + arg.is_read = is_read_states[i] + + # evaluate rhs binary operator argument right = adj.eval(node.right) name = builtin_operators[type(node.op)] @@ -2017,7 +2083,18 @@ def emit_Call(adj, node): args = tuple(adj.resolve_arg(x) for x in node.args) kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords} - # add var with value type from the function + if warp.config.verify_autograd_array_access: + # update arg read/write states according to what happens to that arg in the called function + if hasattr(func, "adj"): + for i, arg in enumerate(args): + if func.adj.args[i].is_write: + kernel_name = adj.fun_name + filename = adj.filename + lineno = adj.lineno + adj.fun_lineno + arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno) + if func.adj.args[i].is_read: + arg.mark_read() + out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs) return out @@ -2097,10 +2174,22 @@ def emit_Subscript(adj, node): if len(indices) == target_type.ndim: # handles array loads (where each dimension has an index specified) out = adj.add_builtin_call("address", [target, *indices]) + + if warp.config.verify_autograd_array_access: + target.mark_read() + else: # handles array views (fewer indices than dimensions) out = adj.add_builtin_call("view", [target, *indices]) + if warp.config.verify_autograd_array_access: + # store reference to target Var to propagate downstream read/write state back to root arg Var + out.parent = target + + # view arg inherits target Var's read/write states + out.is_read = target.is_read + out.is_write = target.is_write + else: # handles non-array type indexing, e.g: vec3, mat33, etc out = adj.add_builtin_call("extract", [target, *indices]) @@ -2184,6 +2273,13 @@ def emit_Assign(adj, node): if is_array(target_type): adj.add_builtin_call("array_store", [target, *indices, rhs]) + if warp.config.verify_autograd_array_access: + kernel_name = adj.fun_name + filename = adj.filename + lineno = adj.lineno + adj.fun_lineno + + target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno) + elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type): if is_reference(target.type): attr = adj.add_builtin_call("indexref", [target, *indices]) diff --git a/warp/config.py b/warp/config.py index 
4cfb197d0..31d6c616e 100644 --- a/warp/config.py +++ b/warp/config.py @@ -39,6 +39,9 @@ quiet: bool = False """Suppress all output except errors and warnings.""" +verify_autograd_array_access: bool = False +"""print warnings related to array overwrites that may result in incorrect gradients""" + cache_kernels: bool = True """If `True`, kernels that have already been compiled from previous application launches will not be recompiled.""" diff --git a/warp/context.py b/warp/context.py index 0ba6473ef..f69154f94 100644 --- a/warp/context.py +++ b/warp/context.py @@ -1748,7 +1748,13 @@ def load(self, device) -> bool: build_dir = None - if not os.path.exists(binary_path) or not warp.config.cache_kernels: + # we always want to build if binary doesn't exist yet + # and we want to rebuild if we are not caching kernels or if we are tracking array access + if ( + not os.path.exists(binary_path) + or not warp.config.cache_kernels + or warp.config.verify_autograd_array_access + ): builder = ModuleBuilder(self, self.options) # create a temporary (process unique) dir for build outputs before moving to the binary dir @@ -4780,6 +4786,10 @@ def pack_args(args, params, adjoint=False): caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name} runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device, metadata={"caller": caller}) + # detect illegal inter-kernel read/write access patterns if verification flag is set + if warp.config.verify_autograd_array_access: + runtime.tape.check_kernel_array_access(kernel, fwd_args) + def synchronize(): """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices @@ -5261,6 +5271,9 @@ def copy( ), arrays=[dest, src], ) + if warp.config.verify_autograd_array_access: + dest.mark_write() + src.mark_read() def adj_copy( diff --git a/warp/tape.py b/warp/tape.py index 1c12b4e1d..8c3cc103e 100644 --- a/warp/tape.py +++ b/warp/tape.py @@ -165,7 +165,7 @@ def record_func(self, backward, arrays): Args: backward (Callable): A callable Python object (can be any function) that will be executed in the backward pass. - arrays (list): A list of arrays that are used by the function for gradient tracking. + arrays (list): A list of arrays that are used by the backward function. 
The tape keeps track of these to be able to zero their gradients in Tape.zero() """ self.launches.append(backward) @@ -197,6 +197,24 @@ def record_scope_end(self, remove_scope_if_empty=True): else: self.scopes.append((len(self.launches), None, None)) + def check_kernel_array_access(self, kernel, args): + """Detect illegal inter-kernel write after read access patterns during launch capture""" + adj = kernel.adj + kernel_name = adj.fun_name + filename = adj.filename + lineno = adj.fun_lineno + + for i, arg in enumerate(args): + if isinstance(arg, wp.array): + arg_name = adj.args[i].label + + # we check write condition first because we allow (write --> read) within the same kernel + if adj.args[i].is_write: + arg.mark_write(arg_name=arg_name, kernel_name=kernel_name, filename=filename, lineno=lineno) + + if adj.args[i].is_read: + arg.mark_read() + # returns the adjoint of a kernel parameter def get_adjoint(self, a): if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance): @@ -237,6 +255,8 @@ def reset(self): self.launches = [] self.scopes = [] self.zero() + if wp.config.verify_autograd_array_access: + self.reset_array_read_flags() def zero(self): """ @@ -251,6 +271,14 @@ def zero(self): else: g.zero_() + def reset_array_read_flags(self): + """ + Reset all recorded array read flags to False + """ + for a in self.gradients: + if isinstance(a, wp.array): + a.mark_init() + def visualize( self, filename: str = None, diff --git a/warp/tests/test_overwrite.py b/warp/tests/test_overwrite.py new file mode 100644 index 000000000..26e24a29e --- /dev/null +++ b/warp/tests/test_overwrite.py @@ -0,0 +1,542 @@ +import contextlib +import io +import unittest + +import numpy as np + +import warp as wp +from warp.tests.unittest_utils import * + +# kernels are defined in the global scope, to ensure wp.Kernel objects are not GC'ed in the MGPU case +# kernel args are assigned array modes during codegen, so wp.Kernel objects generated during codegen +# must be preserved for overwrite tracking to function + + +@wp.kernel +def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)): + tid = wp.tid() + y[tid] = x[tid] * x[tid] + + +@wp.kernel +def overwrite_kernel_a(z: wp.array(dtype=float), x: wp.array(dtype=float)): + tid = wp.tid() + x[tid] = z[tid] + + +# (kernel READ) -> (kernel WRITE) failure case +def test_kernel_read_kernel_write(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + c = wp.array(np.array([-1.0, -2.0, -3.0]), dtype=float, requires_grad=True, device=device) + + tape = wp.Tape() + + with contextlib.redirect_stdout(io.StringIO()) as f: + with tape: + wp.launch(square_kernel, a.shape, inputs=[a], outputs=[b], device=device) + wp.launch(overwrite_kernel_a, c.shape, inputs=[c], outputs=[a], device=device) + + expected = "is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass." 
+ test.assertIn(expected, f.getvalue()) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +@wp.kernel +def double_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)): + tid = wp.tid() + y[tid] = 2.0 * x[tid] + + +@wp.kernel +def triple_kernel(y: wp.array(dtype=float), z: wp.array(dtype=float)): + tid = wp.tid() + z[tid] = 3.0 * y[tid] + + +@wp.kernel +def overwrite_kernel_b(w: wp.array(dtype=float), y: wp.array(dtype=float)): + tid = wp.tid() + y[tid] = 1.0 * w[tid] + + +# (kernel WRITE) -> (kernel READ) -> (kernel WRITE) failure case +def test_kernel_write_kernel_read_kernel_write(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + tape = wp.Tape() + + a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + c = wp.zeros_like(a) + d = wp.zeros_like(a) + + with contextlib.redirect_stdout(io.StringIO()) as f: + with tape: + wp.launch(double_kernel, a.shape, inputs=[a], outputs=[b], device=device) + wp.launch(triple_kernel, b.shape, inputs=[b], outputs=[c], device=device) + wp.launch(overwrite_kernel_b, d.shape, inputs=[d], outputs=[b], device=device) + + expected = "is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass." + test.assertIn(expected, f.getvalue()) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +@wp.kernel +def read_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)): + tid = wp.tid() + b[tid] = a[tid] + + +@wp.kernel +def writeread_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float)): + tid = wp.tid() + a[tid] = c[tid] * c[tid] + b[tid] = a[tid] + + +# (kernel READ) -> (kernel WRITE -> READ) failure case +def test_kernel_read_kernel_writeread(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.array(np.arange(5), dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + c = wp.zeros_like(a) + d = wp.zeros_like(a) + + tape = wp.Tape() + + with contextlib.redirect_stdout(io.StringIO()) as f: + with tape: + wp.launch(read_kernel, dim=5, inputs=[a, b], device=device) + wp.launch(writeread_kernel, dim=5, inputs=[a, d, c], device=device) + + expected = "is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass." 
+ test.assertIn(expected, f.getvalue()) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +@wp.kernel +def write_kernel(a: wp.array(dtype=float), d: wp.array(dtype=float)): + tid = wp.tid() + a[tid] = d[tid] + + +# (kernel WRITE -> READ) -> (kernel WRITE) failure case +def test_kernel_writeread_kernel_write(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + c = wp.array(np.arange(5), dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(c) + a = wp.zeros_like(c) + d = wp.zeros_like(c) + + tape = wp.Tape() + + with contextlib.redirect_stdout(io.StringIO()) as f: + with tape: + wp.launch(writeread_kernel, dim=5, inputs=[a, b, c], device=device) + wp.launch(write_kernel, dim=5, inputs=[a, d], device=device) + + expected = "is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass." + test.assertIn(expected, f.getvalue()) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +@wp.func +def read_func(a: wp.array(dtype=float), idx: int): + x = a[idx] + return x + + +@wp.func +def read_return_func(b: wp.array(dtype=float), idx: int): + return 1.0, b[idx] + + +@wp.func +def write_func(c: wp.array(dtype=float), idx: int): + c[idx] = 1.0 + + +@wp.func +def main_func(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float), idx: int): + x = read_func(a, idx) + y, z = read_return_func(b, idx) + write_func(c, idx) + return x + y + z + + +@wp.kernel +def func_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float), c: wp.array(dtype=float), d: wp.array(dtype=float)): + tid = wp.tid() + d[tid] = main_func(a, b, c, tid) + + +# test various ways one might write to or read from an array inside warp functions +def test_nested_function_read_write(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.zeros(5, dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + c = wp.zeros_like(a) + d = wp.zeros_like(a) + + tape = wp.Tape() + + with tape: + wp.launch(func_kernel, dim=5, inputs=[a, b, c, d], device=device) + + test.assertEqual(a._is_read, True) + test.assertEqual(b._is_read, True) + test.assertEqual(c._is_read, False) + test.assertEqual(d._is_read, False) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +@wp.kernel +def slice_kernel(x: wp.array3d(dtype=float), y: wp.array3d(dtype=float)): + i, j, k = wp.tid() + x_slice = x[i, j] + val = x_slice[k] + + y_slice = y[i, j] + y_slice[k] = val + + +# test updating array r/w mode after indexing +def test_multidimensional_indexing(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = np.arange(3, dtype=float) + b = np.tile(a, (3, 3, 1)) + x = wp.array3d(b, dtype=float, requires_grad=True, device=device) + y = wp.zeros_like(x) + + tape = wp.Tape() + + with tape: + wp.launch(slice_kernel, dim=(3, 3, 3), inputs=[x, y], device=device) + + test.assertEqual(x._is_read, True) + test.assertEqual(y._is_read, False) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +@wp.kernel +def inplace_a(x: wp.array(dtype=float)): + 
tid = wp.tid() + x[tid] += 1.0 + + +@wp.kernel +def inplace_b(x: wp.array(dtype=float), y: wp.array(dtype=float)): + tid = wp.tid() + x[tid] += y[tid] + + +# in-place operators are treated as write +def test_in_place_operators(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.zeros(3, dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + + tape = wp.Tape() + + with tape: + wp.launch(inplace_a, dim=3, inputs=[a], device=device) + + test.assertEqual(a._is_read, False) + + tape.reset() + a.zero_() + + with tape: + wp.launch(inplace_b, dim=3, inputs=[a, b], device=device) + + test.assertEqual(a._is_read, False) + test.assertEqual(b._is_read, True) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +def test_views(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.zeros((3, 3), dtype=float, requires_grad=True, device=device) + test.assertEqual(a._is_read, False) + + a.mark_write() + + b = a.view(dtype=int) + test.assertEqual(b._is_read, False) + + c = b.flatten() + test.assertEqual(c._is_read, False) + + c.mark_read() + test.assertEqual(a._is_read, True) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +def test_reset(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + + tape = wp.Tape() + with tape: + wp.launch(kernel=write_kernel, dim=3, inputs=[b, a], device=device) + + tape.backward(grads={b: wp.ones(3, dtype=float, device=device)}) + + test.assertEqual(a._is_read, True) + test.assertEqual(b._is_read, False) + + tape.reset() + + test.assertEqual(a._is_read, False) + test.assertEqual(b._is_read, False) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +# wp.copy uses wp.record_func. Ensure array modes are propagated correctly. +def test_copy(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + + tape = wp.Tape() + + with tape: + wp.copy(b, a) + + test.assertEqual(a._is_read, True) + test.assertEqual(b._is_read, False) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +# wp.matmul uses wp.record_func. Ensure array modes are propagated correctly. +def test_matmul(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.ones((3, 3), dtype=float, requires_grad=True, device=device) + b = wp.ones_like(a) + c = wp.ones_like(a) + d = wp.zeros_like(a) + + tape = wp.Tape() + + with tape: + wp.matmul(a, b, c, d) + + test.assertEqual(a._is_read, True) + test.assertEqual(b._is_read, True) + test.assertEqual(c._is_read, True) + test.assertEqual(d._is_read, False) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +# wp.batched_matmul uses wp.record_func. 
Ensure array modes are propagated correctly. +def test_batched_matmul(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + a = wp.ones((1, 3, 3), dtype=float, requires_grad=True, device=device) + b = wp.ones_like(a) + c = wp.ones_like(a) + d = wp.zeros_like(a) + + tape = wp.Tape() + + with tape: + wp.batched_matmul(a, b, c, d) + + test.assertEqual(a._is_read, True) + test.assertEqual(b._is_read, True) + test.assertEqual(c._is_read, True) + test.assertEqual(d._is_read, False) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +# write after read warning with in-place operators within a kernel +def test_in_place_operators_warning(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + with contextlib.redirect_stdout(io.StringIO()) as f: + + @wp.kernel + def inplace_c(x: wp.array(dtype=float)): + tid = wp.tid() + x[tid] = 1.0 + a = x[tid] + x[tid] += a + + a = wp.zeros(3, dtype=float, requires_grad=True, device=device) + + tape = wp.Tape() + with tape: + wp.launch(inplace_c, dim=3, inputs=[a], device=device) + + expected = "is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass." + test.assertIn(expected, f.getvalue()) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +# (kernel READ -> WRITE) failure case +def test_kernel_readwrite(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + with contextlib.redirect_stdout(io.StringIO()) as f: + + @wp.kernel + def readwrite_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)): + tid = wp.tid() + b[tid] = a[tid] * a[tid] + a[tid] = 1.0 + + a = wp.array(np.arange(5), dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + + tape = wp.Tape() + with tape: + wp.launch(readwrite_kernel, dim=5, inputs=[a, b], device=device) + + expected = "is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass." + test.assertIn(expected, f.getvalue()) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +# (kernel READ -> func WRITE) codegen failure case +def test_kernel_read_func_write(test, device): + saved_verify_autograd_array_access_setting = wp.config.verify_autograd_array_access + try: + wp.config.verify_autograd_array_access = True + + with contextlib.redirect_stdout(io.StringIO()) as f: + + @wp.func + def write_func_2(x: wp.array(dtype=float), idx: int): + x[idx] = 2.0 + + @wp.kernel + def read_kernel_func_write(x: wp.array(dtype=float), y: wp.array(dtype=float)): + tid = wp.tid() + a = x[tid] + write_func_2(x, tid) + y[tid] = a + + a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True, device=device) + b = wp.zeros_like(a) + + tape = wp.Tape() + with tape: + wp.launch(kernel=read_kernel_func_write, dim=3, inputs=[a, b], device=device) + + expected = "written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass." 
+ test.assertIn(expected, f.getvalue()) + + finally: + wp.config.verify_autograd_array_access = saved_verify_autograd_array_access_setting + + +class TestOverwrite(unittest.TestCase): + pass + + +devices = get_test_devices() + +add_function_test(TestOverwrite, "test_kernel_read_kernel_write", test_kernel_read_kernel_write, devices=devices) +add_function_test( + TestOverwrite, + "test_kernel_write_kernel_read_kernel_write", + test_kernel_write_kernel_read_kernel_write, + devices=devices, +) +add_function_test( + TestOverwrite, "test_kernel_read_kernel_writeread", test_kernel_read_kernel_writeread, devices=devices +) +add_function_test( + TestOverwrite, "test_kernel_writeread_kernel_write", test_kernel_writeread_kernel_write, devices=devices +) +add_function_test(TestOverwrite, "test_nested_function_read_write", test_nested_function_read_write, devices=devices) +add_function_test(TestOverwrite, "test_multidimensional_indexing", test_multidimensional_indexing, devices=devices) +add_function_test(TestOverwrite, "test_in_place_operators", test_in_place_operators, devices=devices) +add_function_test(TestOverwrite, "test_views", test_views, devices=devices) +add_function_test(TestOverwrite, "test_reset", test_reset, devices=devices) + +add_function_test(TestOverwrite, "test_copy", test_copy, devices=devices) +add_function_test(TestOverwrite, "test_matmul", test_matmul, devices=devices) +add_function_test(TestOverwrite, "test_batched_matmul", test_batched_matmul, devices=devices) + +# Some warning are only issued during codegen, and codegen only runs on cuda_0 in the MGPU case. +cuda_device = get_cuda_test_devices(mode="basic") + +add_function_test( + TestOverwrite, "test_in_place_operators_warning", test_in_place_operators_warning, devices=cuda_device +) +add_function_test(TestOverwrite, "test_kernel_readwrite", test_kernel_readwrite, devices=cuda_device) +add_function_test(TestOverwrite, "test_kernel_read_func_write", test_kernel_read_func_write, devices=cuda_device) + +if __name__ == "__main__": + wp.build.clear_kernel_cache() + unittest.main(verbosity=2) diff --git a/warp/tests/unittest_suites.py b/warp/tests/unittest_suites.py index 4f4708842..09eec7f0d 100644 --- a/warp/tests/unittest_suites.py +++ b/warp/tests/unittest_suites.py @@ -151,6 +151,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader) from warp.tests.test_noise import TestNoise from warp.tests.test_operators import TestOperators from warp.tests.test_options import TestOptions + from warp.tests.test_overwrite import TestOverwrite from warp.tests.test_peer import TestPeer from warp.tests.test_pinned import TestPinned from warp.tests.test_print import TestPrint @@ -244,6 +245,7 @@ def default_suite(test_loader: unittest.TestLoader = unittest.defaultTestLoader) TestNoise, TestOperators, TestOptions, + TestOverwrite, TestPeer, TestPinned, TestPrint, diff --git a/warp/types.py b/warp/types.py index e7d0524fa..2a0e1be2a 100644 --- a/warp/types.py +++ b/warp/types.py @@ -1701,6 +1701,9 @@ def __init__( else: self._init_annotation(dtype, ndim or 1) + # initialize read flag + self.mark_init() + # initialize gradient, if needed if self.device is not None: if grad is not None: @@ -1712,6 +1715,9 @@ def __init__( if requires_grad: self._alloc_grad() + # reference to other array + self._ref = None + def _init_from_data(self, data, dtype, shape, device, copy, pinned): if not hasattr(data, "__len__"): raise RuntimeError(f"Data must be a sequence or array, got scalar {data}") @@ -2307,6 +2313,33 @@ def vars(self): 
array._vars = {"shape": warp.codegen.Var("shape", shape_t)} return array._vars + def mark_init(self): + """Resets this array's read flag""" + self._is_read = False + + def mark_read(self): + """Marks this array as having been read from in a kernel or recorded function on the tape.""" + # no additional checks required: it is always safe to set an array to READ + self._is_read = True + + # recursively update all parent arrays + parent = self._ref + while parent is not None: + parent._is_read = True + parent = parent._ref + + def mark_write(self, **kwargs): + """Detect if we are writing to an array that has already been read from""" + if self._is_read: + if "arg_name" and "kernel_name" and "filename" and "lineno" in kwargs: + print( + f"Warning: Array {self} passed to argument {kwargs['arg_name']} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass." + ) + else: + print( + f"Warning: Array {self} is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass." + ) + def zero_(self): """Zeroes-out the array entries.""" if self.is_contiguous: @@ -2314,6 +2347,7 @@ def zero_(self): self.device.memset(self.ptr, 0, self.size * type_size_in_bytes(self.dtype)) else: self.fill_(0) + self.mark_init() def fill_(self, value): """Set all array entries to `value` @@ -2388,6 +2422,8 @@ def fill_(self, value): else: warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size) + self.mark_init() + def assign(self, src): """Wraps ``src`` in an :class:`warp.array` if it is not already one and copies the contents to ``self``.""" if is_array(src): @@ -2494,6 +2530,9 @@ def flatten(self): grad=None if self.grad is None else self.grad.flatten(), ) + # transfer read flag + a._is_read = self._is_read + # store back-ref to stop data being destroyed a._ref = self return a @@ -2555,6 +2594,9 @@ def reshape(self, shape): grad=None if self.grad is None else self.grad.reshape(shape), ) + # transfer read flag + a._is_read = self._is_read + # store back-ref to stop data being destroyed a._ref = self return a @@ -2578,6 +2620,9 @@ def view(self, dtype): grad=None if self.grad is None else self.grad.view(dtype), ) + # transfer read flag + a._is_read = self._is_read + a._ref = self return a @@ -2631,6 +2676,9 @@ def transpose(self, axes=None): a.is_transposed = not self.is_transposed + # transfer read flag + a._is_read = self._is_read + a._ref = self return a @@ -3934,6 +3982,11 @@ def matmul( backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith), arrays=[a, b, c, d], ) + if warp.config.verify_autograd_array_access: + d.mark_write() + a.mark_read() + b.mark_read() + c.mark_read() # cpu fallback if no cuda devices found if device == "cpu": @@ -4219,6 +4272,11 @@ def batched_matmul( ), arrays=[a, b, c, d], ) + if warp.config.verify_autograd_array_access: + d.mark_write() + a.mark_read() + b.mark_read() + c.mark_read() # cpu fallback if no cuda devices found if device == "cpu":
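
For reference, the view-handling changes in ``warp/types.py`` propagate read state from a view back to its parent array through the ``_ref`` back-reference. A minimal sketch of the resulting behavior, mirroring ``test_views`` above (``_is_read`` and ``_ref`` are internal attributes, shown only for illustration):

    import warp as wp

    a = wp.zeros((3, 3), dtype=float, requires_grad=True)

    b = a.flatten()   # the view inherits a's read flag and stores a back-reference in b._ref
    b.mark_read()     # marking the view as read also marks the parent

    assert a._is_read is True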