Merge branch 'array_overwrite' into 'main'
Array Overwrite Notifications

See merge request omniverse/warp!457
mmacklin committed Jul 23, 2024
2 parents 9cf4c36 + d2be8ea commit c1915c9
Showing 9 changed files with 862 additions and 3 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -29,6 +29,15 @@
- Fixed differentiability of `wp.fem.PicQuadrature` w.r.t. positions and measures
- Improve error messages for unsupported constructs
- Update `wp.matmul()` CPU fallback to use dtype explicitly in `np.matmul()` call
- Add array overwrite detection if `wp.config.verify_autograd_array_access` is True. Array overwrites on the tape may corrupt gradient computation in the backward pass
  - Adds `is_read` and `is_write` flags to kernel array args, which are set to `True` if an array arg is determined to be read from and/or written to during compilation
  - If a kernel array arg is read from and then written to within the same kernel, a warning is printed during compilation
  - Adds an `is_read` flag to Warp arrays, which is used to track whether an array has been read from in a kernel or recorded function at runtime
  - If a Warp array is passed to a kernel arg with attribute `is_read = True`, the array's `is_read` flag is set to `True`
  - If a Warp array with attribute `is_read = True` is subsequently passed to a kernel arg with attribute `is_write = True` (a write-after-read overwrite condition), a warning is printed indicating that gradient corruption is possible in the backward pass
  - Adds `wp.array.mark_write()` and `wp.array.mark_read()`, which are used to manually mark arrays that are written to or read from in functions recorded with `wp.Tape.record_func()`
  - Adds a `wp.Tape.reset_array_read_flags()` method, which resets all tape array `is_read` flags to `False`
  - Configures all view-like array methods to inherit the `is_read` flag from the parent array at creation
- Fix ShapeInstancer `__new__()` method (missing instance return and `*args` parameter)
- Add support for PEP 563's `from __future__ import annotations`.
- Allow passing external arrays/tensors to Warp kernels directly via `__cuda_array_interface__` and `__array_interface__`
@@ -40,6 +49,7 @@
## [1.2.2] - 2024-07-04

- Support for NumPy >= 2.0
- Fix hashing of replay functions and snippets
- Add additional documentation and examples demonstrating `wp.copy()`, `wp.clone()`, and `array.assign()` differentiability
- Fix adding `__new__()` methods for all classes with `__del__()` methods to anticipate when a class
  instance is created but not fully initialized before garbage collection.
107 changes: 107 additions & 0 deletions docs/modules/differentiability.rst
@@ -772,6 +772,113 @@ In the example above we can see that the array ``c`` does not have its ``require
.. note::
Arrays can be labeled with custom names using the ``array_labels`` argument to the ``tape.visualize()`` method.

Array Overwrite Tracking
^^^^^^^^^^^^^^^^^^^^^^^^^

It is a common mistake to inadvertently overwrite an array that participates in the computation graph. For example::

    tape = wp.Tape()

    with tape:

# step 1
wp.launch(compute_forces, dim=n, inputs=[pos0, vel0], outputs=[force])
wp.launch(simulate, dim=n, inputs=[pos0, vel0, force], outputs=[pos1, vel1])

# step 2 (error, we are overwriting previous forces)
wp.launch(compute_forces, dim=n, inputs=[pos1, vel1], outputs=[force])
wp.launch(simulate, dim=n, inputs=[pos1, vel1, force], outputs=[pos2, vel2])

# compute loss
        wp.launch(loss_kernel, dim=n, inputs=[pos2, loss])

tape.backward(loss)

Running the tape backward will compute incorrect gradients of the loss with respect to ``pos0`` and ``vel0``, because ``force`` is overwritten in the second simulation step.
The adjoints propagated from ``force`` back to ``pos1`` and ``vel1`` will be correct, because the stored value of ``force`` from step 2 of the forward pass is still valid, but the adjoints
propagated from ``force`` back to ``pos0`` and ``vel0`` will be incorrect, because the ``force`` values used in that part of the backward pass were computed in step 2, not step 1. The solution
is to allocate two force arrays, ``force0`` and ``force1``, so that we are not overwriting data that participates in the computation graph.
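
For illustration, a corrected sketch of the same sequence (assuming the same schematic kernels and state arrays as above, with the single ``force`` array replaced by ``force0`` and ``force1``) might look like::

    tape = wp.Tape()

    with tape:

        # step 1 writes force0
        wp.launch(compute_forces, dim=n, inputs=[pos0, vel0], outputs=[force0])
        wp.launch(simulate, dim=n, inputs=[pos0, vel0, force0], outputs=[pos1, vel1])

        # step 2 writes force1, leaving force0 intact for the backward pass
        wp.launch(compute_forces, dim=n, inputs=[pos1, vel1], outputs=[force1])
        wp.launch(simulate, dim=n, inputs=[pos1, vel1, force1], outputs=[pos2, vel2])

        # compute loss
        wp.launch(loss_kernel, dim=n, inputs=[pos2, loss])

    tape.backward(loss)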

This sort of problem boils down to a single pattern to be avoided: writing to an array after reading from it. This typically happens over consecutive kernel launches (A), but it might also happen within a single kernel (B).

A: Inter-Kernel Overwrite::

    import numpy as np

    import warp as wp

@wp.kernel
def square_kernel(x: wp.array(dtype=float), y: wp.array(dtype=float)):
tid = wp.tid()
y[tid] = x[tid] * x[tid]

@wp.kernel
def overwrite_kernel(z: wp.array(dtype=float), x: wp.array(dtype=float)):
tid = wp.tid()
x[tid] = z[tid]

@wp.kernel
def loss_kernel(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
tid = wp.tid()
wp.atomic_add(loss, 0, x[tid])

a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True)
b = wp.zeros_like(a)
c = wp.array(np.array([-1.0, -2.0, -3.0]), dtype=float, requires_grad=True)
loss = wp.zeros(1, dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
wp.launch(square_kernel, a.shape, inputs=[a], outputs=[b])
wp.launch(overwrite_kernel, c.shape, inputs=[c], outputs=[a])
        wp.launch(loss_kernel, a.shape, inputs=[b, loss])

tape.backward(loss)

print(a.grad)
# prints [-2. -4. -6.] instead of [2. 4. 6.]

B: Intra-Kernel Overwrite::

    import numpy as np

    import warp as wp

@wp.kernel
def readwrite_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
tid = wp.tid()
b[tid] = a[tid] * a[tid]
a[tid] = 1.0

@wp.kernel
def loss_kernel(x: wp.array(dtype=float), loss: wp.array(dtype=float)):
tid = wp.tid()
wp.atomic_add(loss, 0, x[tid])

a = wp.array(np.array([1.0, 2.0, 3.0]), dtype=float, requires_grad=True)
b = wp.zeros_like(a)
loss = wp.zeros(1, dtype=float, requires_grad=True)

tape = wp.Tape()
with tape:
wp.launch(readwrite_kernel, dim=a.shape, inputs=[a, b])
wp.launch(loss_kernel, a.shape, inputs=[b, loss])

tape.backward(loss)

print(a.grad)
# prints [2. 2. 2.] instead of [2. 4. 6.]

If ``wp.config.verify_autograd_array_access = True`` is set, Warp will automatically detect and report array overwrites, covering the above two cases as well as other problematic configurations.
It does so by flagging which kernel array arguments are read from and/or written to in each kernel function during compilation. At runtime, if an array is passed to a kernel argument marked with a read flag,
it is marked as having been read from. Later, if the same array is passed to a kernel argument marked with a write flag, a warning is printed
(recall the pattern we wish to avoid: *write* after *read*).
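
As a minimal sketch (reusing the kernel and array definitions from case A above; the exact warning text may differ between Warp versions), enabling detection looks like::

    # must be set before the module is loaded, so kernels are compiled with read/write flags
    wp.config.verify_autograd_array_access = True

    tape = wp.Tape()
    with tape:
        wp.launch(square_kernel, a.shape, inputs=[a], outputs=[b])     # a is marked as read
        wp.launch(overwrite_kernel, c.shape, inputs=[c], outputs=[a])  # a is written after being read -> warning is printed
        wp.launch(loss_kernel, a.shape, inputs=[b, loss])

    tape.backward(loss)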

.. note::
Setting ``wp.config.verify_autograd_array_access = True`` will disable kernel caching and force the current module to rebuild.

.. note::
    Though in-place operations such as ``x[tid] += 1.0`` are technically ``read -> write``, the Warp graph specifically accommodates adjoint accumulation in these cases, so we mark them as write operations.
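
For instance, a kernel such as the following (an illustrative sketch, not taken from the Warp source) performs an in-place update, and its array argument is therefore flagged only as written to::

    @wp.kernel
    def accumulate_kernel(x: wp.array(dtype=float)):
        tid = wp.tid()
        # read-modify-write, but recorded as a write for overwrite-detection purposes
        x[tid] += 1.0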

.. note::
This feature does not yet support arrays packed in Warp structs.

If you make use of :py:meth:`Tape.record_func` in your graph (and so provide your own adjoint callback), be sure to also call :py:meth:`array.mark_write()` and :py:meth:`array.mark_read()`, which will manually mark your arrays as having been written to or read from.
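
For illustration, a minimal sketch of this pattern (the forward computation, backward callable, and array names are hypothetical, and the adjoint math is specific to this toy ``y = 2 * x`` operation)::

    import warp as wp

    x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
    y = wp.zeros_like(x)

    def backward_scale():
        # custom adjoint for y = 2 * x: accumulate 2 * y.grad into x.grad
        x.grad.assign(x.grad.numpy() + 2.0 * y.grad.numpy())

    tape = wp.Tape()
    with tape:
        # forward computation performed outside of a Warp kernel
        y.assign(2.0 * x.numpy())
        tape.record_func(backward_scale, arrays=[x, y])

        # overwrite tracking cannot see inside record_func, so mark array accesses manually
        x.mark_read()
        y.mark_write()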

.. _limitations_and_workarounds:

Limitations and Workarounds
98 changes: 97 additions & 1 deletion warp/codegen.py
@@ -582,6 +582,14 @@ def __init__(self, label, type, requires_grad=False, constant=None, prefix=True)
self.constant = constant
self.prefix = prefix

# records whether this Var has been read from in a kernel function (array only)
self.is_read = False
# records whether this Var has been written to in a kernel function (array only)
self.is_write = False

# used to associate a view array Var with its parent array Var
self.parent = None

def __str__(self):
return self.label

@@ -624,6 +632,42 @@ def emit(self, prefix: str = "var"):
def emit_adj(self):
return self.emit("adj")

def mark_read(self):
"""Marks this Var as having been read from in a kernel (array only)."""
if not is_array(self.type):
return

self.is_read = True

# recursively update all parent states
parent = self.parent
while parent is not None:
parent.is_read = True
parent = parent.parent

def mark_write(self, **kwargs):
"""Marks this Var has having been written to in a kernel (array only)."""
if not is_array(self.type):
return

# detect if we are writing to an array after reading from it within the same kernel
if self.is_read and warp.config.verify_autograd_array_access:
if "kernel_name" and "filename" and "lineno" in kwargs:
print(
f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
)
else:
print(
f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
)
self.is_write = True

# recursively update all parent states
parent = self.parent
while parent is not None:
parent.is_write = True
parent = parent.parent


class Block:
# Represents a basic block of instructions, e.g.: list
@@ -821,6 +865,11 @@ def __init__(

# generate function ssa form and adjoint
def build(adj, builder, default_builder_options=None):
# arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
for arg in adj.args:
arg.is_read = False
arg.is_write = False

if adj.skip_build:
return

@@ -1702,7 +1751,24 @@ def emit_Constant(adj, node):

def emit_BinOp(adj, node):
# evaluate binary operator arguments

if warp.config.verify_autograd_array_access:
# array overwrite tracking: in-place operators are a special case
# x[tid] = x[tid] + 1 is a read followed by a write, but we only want to record the write
# so we save the current arg read flags and restore them after lhs eval
is_read_states = []
for arg in adj.args:
is_read_states.append(arg.is_read)

# evaluate lhs binary operator argument
left = adj.eval(node.left)

if warp.config.verify_autograd_array_access:
# restore arg read flags
for i, arg in enumerate(adj.args):
arg.is_read = is_read_states[i]

# evaluate rhs binary operator argument
right = adj.eval(node.right)

name = builtin_operators[type(node.op)]
@@ -2017,7 +2083,18 @@ def emit_Call(adj, node):
args = tuple(adj.resolve_arg(x) for x in node.args)
kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}

# add var with value type from the function
if warp.config.verify_autograd_array_access:
# update arg read/write states according to what happens to that arg in the called function
if hasattr(func, "adj"):
for i, arg in enumerate(args):
if func.adj.args[i].is_write:
kernel_name = adj.fun_name
filename = adj.filename
lineno = adj.lineno + adj.fun_lineno
arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
if func.adj.args[i].is_read:
arg.mark_read()

out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
return out

@@ -2097,10 +2174,22 @@ def emit_Subscript(adj, node):
if len(indices) == target_type.ndim:
# handles array loads (where each dimension has an index specified)
out = adj.add_builtin_call("address", [target, *indices])

if warp.config.verify_autograd_array_access:
target.mark_read()

else:
# handles array views (fewer indices than dimensions)
out = adj.add_builtin_call("view", [target, *indices])

if warp.config.verify_autograd_array_access:
# store reference to target Var to propagate downstream read/write state back to root arg Var
out.parent = target

# view arg inherits target Var's read/write states
out.is_read = target.is_read
out.is_write = target.is_write

else:
# handles non-array type indexing, e.g: vec3, mat33, etc
out = adj.add_builtin_call("extract", [target, *indices])
@@ -2184,6 +2273,13 @@ def emit_Assign(adj, node):
if is_array(target_type):
adj.add_builtin_call("array_store", [target, *indices, rhs])

if warp.config.verify_autograd_array_access:
kernel_name = adj.fun_name
filename = adj.filename
lineno = adj.lineno + adj.fun_lineno

target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)

elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
if is_reference(target.type):
attr = adj.add_builtin_call("indexref", [target, *indices])
3 changes: 3 additions & 0 deletions warp/config.py
@@ -39,6 +39,9 @@
quiet: bool = False
"""Suppress all output except errors and warnings."""

verify_autograd_array_access: bool = False
"""print warnings related to array overwrites that may result in incorrect gradients"""

cache_kernels: bool = True
"""If `True`, kernels that have already been compiled from previous application launches will not be recompiled."""

15 changes: 14 additions & 1 deletion warp/context.py
@@ -1748,7 +1748,13 @@ def load(self, device) -> bool:

build_dir = None

if not os.path.exists(binary_path) or not warp.config.cache_kernels:
# we always want to build if binary doesn't exist yet
# and we want to rebuild if we are not caching kernels or if we are tracking array access
if (
not os.path.exists(binary_path)
or not warp.config.cache_kernels
or warp.config.verify_autograd_array_access
):
builder = ModuleBuilder(self, self.options)

# create a temporary (process unique) dir for build outputs before moving to the binary dir
@@ -4780,6 +4786,10 @@ def pack_args(args, params, adjoint=False):
caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device, metadata={"caller": caller})

# detect illegal inter-kernel read/write access patterns if verification flag is set
if warp.config.verify_autograd_array_access:
runtime.tape.check_kernel_array_access(kernel, fwd_args)


def synchronize():
"""Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices
@@ -5261,6 +5271,9 @@ def copy(
),
arrays=[dest, src],
)
if warp.config.verify_autograd_array_access:
dest.mark_write()
src.mark_read()


def adj_copy(
30 changes: 29 additions & 1 deletion warp/tape.py
@@ -165,7 +165,7 @@ def record_func(self, backward, arrays):
Args:
backward (Callable): A callable Python object (can be any function) that will be executed in the backward pass.
arrays (list): A list of arrays that are used by the function for gradient tracking.
arrays (list): A list of arrays that are used by the backward function. The tape keeps track of these to be able to zero their gradients in Tape.zero()
"""
self.launches.append(backward)

@@ -197,6 +197,24 @@ def record_scope_end(self, remove_scope_if_empty=True):
else:
self.scopes.append((len(self.launches), None, None))

def check_kernel_array_access(self, kernel, args):
"""Detect illegal inter-kernel write after read access patterns during launch capture"""
adj = kernel.adj
kernel_name = adj.fun_name
filename = adj.filename
lineno = adj.fun_lineno

for i, arg in enumerate(args):
if isinstance(arg, wp.array):
arg_name = adj.args[i].label

# we check write condition first because we allow (write --> read) within the same kernel
if adj.args[i].is_write:
arg.mark_write(arg_name=arg_name, kernel_name=kernel_name, filename=filename, lineno=lineno)

if adj.args[i].is_read:
arg.mark_read()

# returns the adjoint of a kernel parameter
def get_adjoint(self, a):
if not wp.types.is_array(a) and not isinstance(a, wp.codegen.StructInstance):
@@ -237,6 +255,8 @@ def reset(self):
self.launches = []
self.scopes = []
self.zero()
if wp.config.verify_autograd_array_access:
self.reset_array_read_flags()

def zero(self):
"""
@@ -251,6 +271,14 @@ def zero(self):
else:
g.zero_()

def reset_array_read_flags(self):
"""
Reset all recorded array read flags to False
"""
for a in self.gradients:
if isinstance(a, wp.array):
a.mark_init()

def visualize(
self,
filename: str = None,