Skip to content

Commit

Permalink
Merge branch 'featherstone-grad-fix' into 'main'
Browse files Browse the repository at this point in the history
Fix gradient for mass matrix in Featherstone

See merge request omniverse/warp!543
  • Loading branch information
mmacklin committed Jun 6, 2024
2 parents 3fc50e3 + d28e85c commit 7829243
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions warp/sim/integrator_featherstone.py
Original file line number Diff line number Diff line change
Expand Up @@ -1288,6 +1288,7 @@ def dense_solve(
n: int,
L_start: int,
b_start: int,
A: wp.array(dtype=float),
L: wp.array(dtype=float),
b: wp.array(dtype=float),
# outputs
Expand All @@ -1303,13 +1304,14 @@ def adj_dense_solve(
n: int,
L_start: int,
b_start: int,
A: wp.array(dtype=float),
L: wp.array(dtype=float),
b: wp.array(dtype=float),
# outputs
x: wp.array(dtype=float),
tmp: wp.array(dtype=float),
):
if not tmp or not wp.adjoint[x] or not wp.adjoint[L]:
if not tmp or not wp.adjoint[x] or not wp.adjoint[A] or not wp.adjoint[L]:
return
for i in range(n):
tmp[b_start + i] = 0.0
Expand All @@ -1324,12 +1326,17 @@ def adj_dense_solve(
for j in range(n):
wp.adjoint[L][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]

for i in range(n):
for j in range(n):
wp.adjoint[A][L_start + dense_index(n, i, j)] += -tmp[b_start + i] * x[b_start + j]


@wp.kernel
def eval_dense_solve_batched(
L_start: wp.array(dtype=int),
L_dim: wp.array(dtype=int),
b_start: wp.array(dtype=int),
A: wp.array(dtype=float),
L: wp.array(dtype=float),
b: wp.array(dtype=float),
# outputs
Expand All @@ -1338,7 +1345,7 @@ def eval_dense_solve_batched(
):
batch = wp.tid()

dense_solve(L_dim[batch], L_start[batch], b_start[batch], L, b, x, tmp)
dense_solve(L_dim[batch], L_start[batch], b_start[batch], A, L, b, x, tmp)


@wp.kernel
Expand Down Expand Up @@ -1509,7 +1516,6 @@ def allocate_model_aux_vars(self, model):
self.L = wp.zeros_like(self.H)

if model.body_count:
# TODO use requires_grad here?
self.body_I_m = wp.empty(
(model.body_count,), dtype=wp.spatial_matrix, device=model.device, requires_grad=model.requires_grad
)
Expand Down Expand Up @@ -1859,6 +1865,7 @@ def simulate(self, model: Model, state_in: State, state_out: State, dt: float, c
self.articulation_H_start,
self.articulation_H_rows,
self.articulation_dof_start,
self.H,
self.L,
state_aug.joint_tau,
],
Expand Down

0 comments on commit 7829243

Please sign in to comment.