Skip to content

Commit

Permalink
warp.fem: Fix differentiability of PicQuadrature
Browse files Browse the repository at this point in the history
  • Loading branch information
gdaviet committed Jul 6, 2024
1 parent 7de9179 commit 3b48414
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
- Simplified defining custom subdomains (`wp.fem.Subdomain`), free-slip boundary conditions
- New `streamlines` example, updated `mixed_elasticity` to use a nonlinear model
- Fixed edge cases with Nanovdb function spaces
- Fixed differentiability of `wp.fem.PicQuadrature` w.r.t. positions and measures
- Improve error messages for unsupported constructs

## [1.2.2] - 2024-07-04
Expand Down
23 changes: 21 additions & 2 deletions warp/fem/cache.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import bisect
import re
import weakref
from copy import copy
from typing import Any, Callable, Dict, Optional, Tuple, Union

Expand Down Expand Up @@ -208,6 +209,12 @@ def __init__(self, array: wp.array, pool: Optional["TemporaryStore.Pool"] = None
self._array_view = array
self._pool = pool

if pool is not None and wp.context.runtime.tape is not None:
# Extend lifetime to that of Tape (or Pool if shorter)
# This is to prevent temporary arrays held in tape launch parameters to be redeemed
pool.hold(self)
weakref.finalize(wp.context.runtime.tape, TemporaryStore.Pool.stop_holding, pool, self)

if shape is not None or dtype is not None:
self._view_as(shape=shape, dtype=dtype)

Expand Down Expand Up @@ -275,6 +282,8 @@ def __init__(self, dtype, device, pinned: bool):
self._pool_sizes = [] # Sizes of available arrays for borrowing, ascending
self._allocs = {} # All allocated arrays, including borrowed ones

self._held_temporaries = set() # Temporaries that are prevented from going out of scope

def borrow(self, shape, dtype, requires_grad: bool):
size = 1
if isinstance(shape, int):
Expand All @@ -290,8 +299,12 @@ def borrow(self, shape, dtype, requires_grad: bool):
# Big enough array found, remove from pool
array = self._pool.pop(index)
self._pool_sizes.pop(index)
if requires_grad and array.grad is None:
array.requires_grad = True
if requires_grad:
if array.grad is None:
array.requires_grad = True
else:
# Zero-out existing gradient to mimic semantics of wp.empty()
array.grad.zero_()
return Temporary(pool=self, array=array, shape=shape, dtype=dtype)

# No big enough array found, allocate new one
Expand All @@ -317,6 +330,12 @@ def redeem(self, array):
def detach(self, array):
del self._allocs[array.ptr]

def hold(self, temp: Temporary):
self._held_temporaries.add(temp)

def stop_holding(self, temp: Temporary):
self._held_temporaries.remove(temp)

def __init__(self):
self.clear()

Expand Down
9 changes: 5 additions & 4 deletions warp/fem/quadrature/pic_quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from .quadrature import Quadrature

wp.set_module_options({"enable_backward": False})


class PicQuadrature(Quadrature):
"""Particle-based quadrature formula, using a global set of points unevenly spread out over geometry elements.
Expand All @@ -23,6 +21,7 @@ class PicQuadrature(Quadrature):
define a global :meth:`Geometry.cell_lookup` method; currently this is only available for :class:`Grid2D` and :class:`Grid3D`.
measures: Array containing the measure (area/volume) of each particle, used to defined the integration weights.
If ``None``, defaults to the cell measure divided by the number of particles in the cell.
requires_grad: Whether gradients should be allocated for the computed quantities
temporary_store: shared pool from which to allocate temporary arrays
"""

Expand All @@ -37,10 +36,12 @@ def __init__(
],
],
measures: Optional["wp.array(dtype=float)"] = None,
requires_grad: bool = False,
temporary_store: TemporaryStore = None,
):
super().__init__(domain)

self._requires_grad = requires_grad
self._bin_particles(positions, measures, temporary_store)
self._max_particles_per_cell: int = None

Expand Down Expand Up @@ -177,7 +178,7 @@ def bin_particles(
cell_index = cell_index_temp.array

self._particle_coords_temp = borrow_temporary(
temporary_store, shape=positions.shape, dtype=Coords, device=device
temporary_store, shape=positions.shape, dtype=Coords, device=device, requires_grad=self._requires_grad
)
self._particle_coords = self._particle_coords_temp.array

Expand Down Expand Up @@ -211,7 +212,7 @@ def _compute_fraction(self, cell_index, measures, temporary_store: TemporaryStor
device = cell_index.device

self._particle_fraction_temp = borrow_temporary(
temporary_store, shape=cell_index.shape, dtype=float, device=device
temporary_store, shape=cell_index.shape, dtype=float, device=device, requires_grad=self._requires_grad
)
self._particle_fraction = self._particle_fraction_temp.array

Expand Down
15 changes: 15 additions & 0 deletions warp/tests/test_fem.py
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,21 @@ def test_particle_quadratures(test, device):
val = fem.integrate(_piecewise_constant, quadrature=pic_quadrature)
test.assertAlmostEqual(val, 1.25, places=5)

# Test differentiability of PicQuadrature w.r.t positions and measures
points = wp.array([[0.25, 0.33], [0.33, 0.25], [0.8, 0.8]], dtype=wp.vec2, device=device, requires_grad=True)
measures = wp.ones(3, dtype=float, device=device, requires_grad=True)

tape = wp.Tape()
with tape:
pic = fem.PicQuadrature(domain, positions=points, measures=measures, requires_grad=True)

pic.arg_value(device).particle_coords.grad.fill_(1.0)
pic.arg_value(device).particle_fraction.grad.fill_(1.0)
tape.backward()

assert_np_equal(points.grad.numpy(), np.full((3, 2), 2.0)) # == 1.0 / cell_size
assert_np_equal(measures.grad.numpy(), np.full(3, 4.0)) # == 1.0 / cell_area


@wp.kernel
def test_qr_eigenvalues():
Expand Down

0 comments on commit 3b48414

Please sign in to comment.