Skip to content

Commit

Permalink
Updated walker example uses tile API
Browse files Browse the repository at this point in the history
  • Loading branch information
daedalus5 committed Jan 15, 2025
1 parent af3b467 commit e79a945
Show file tree
Hide file tree
Showing 3 changed files with 325 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
- Add optimization example for soft-body properties ([GH-419](https://github.com/NVIDIA/warp/pull/419)).
- Add per-module option to disable fused floating point operations, use `wp.set_module_options({"fuse_fp": False})`
([GH-379](https://github.com/NVIDIA/warp/issues/379)).
- Add `example_tile_walker.py`, which reworks the existing `walker.py` to use Warp's tile API for matrix multiplication.

### Changed

Expand Down
5 changes: 5 additions & 0 deletions warp/examples/optim/example_walker.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
# simulated forward in time and then evaluated based on the center of mass
# momentum of the mesh.
#
# This example uses the deprecated wp.matmul() for matrix multiplication,
# which will be removed in a future version. See the updated version of
# this example, example_tile_walker.py, in examples/tile for the new
# approach to GEMMs using Warp's tile API.
#
###########################################################################

import math
Expand Down
319 changes: 319 additions & 0 deletions warp/examples/tile/example_tile_walker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,319 @@
# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

###########################################################################
# Example Tile Walker
#
# Trains a tetrahedral mesh quadruped to run. Feeds 8 time-varying input
# phases as inputs into a single layer fully connected network with a tanh
# activation function. Interprets the output of the network as tet
# activations, which are fed into the wp.sim soft mesh model. This is
# simulated forward in time and then evaluated based on the center of mass
# momentum of the mesh.
#
# This example uses the Warp tile API, which as of Warp 1.6 is the
# recommended way to handle matrix multiplication. example_walker.py in
# examples/optim demonstrates the old way of doing matrix multiplication,
# wp.matmul(), which will be deprecated in a future version.
#
###########################################################################

import math
import os

import numpy as np
from pxr import Gf, Usd, UsdGeom

import warp as wp
import warp.examples
import warp.optim
import warp.sim
import warp.sim.render

PHASE_COUNT = 8
PHASE_STEP = wp.constant((2.0 * math.pi) / PHASE_COUNT)
PHASE_FREQ = wp.constant(5.0)
ACTIVATION_STRENGTH = wp.constant(0.3)

TILE_TETS = wp.constant(8)
TILE_THREADS = 64


@wp.kernel
def loss_kernel(com: wp.array(dtype=wp.vec3), loss: wp.array(dtype=float)):
tid = wp.tid()
vx = com[tid][0]
vy = com[tid][1]
vz = com[tid][2]
delta = wp.sqrt(vx * vx) + wp.sqrt(vy * vy) - vz

wp.atomic_add(loss, 0, delta)


@wp.kernel
def com_kernel(velocities: wp.array(dtype=wp.vec3), n: int, com: wp.array(dtype=wp.vec3)):
tid = wp.tid()
v = velocities[tid]
a = v / wp.float32(n)
wp.atomic_add(com, 0, a)


@wp.kernel
def compute_phases(phases: wp.array(dtype=float), sim_time: float):
tid = wp.tid()
phases[tid] = wp.sin(PHASE_FREQ * sim_time + wp.float32(tid) * PHASE_STEP)


@wp.func
def tanh(x: float):
return wp.tanh(x) * ACTIVATION_STRENGTH


@wp.kernel
def network(
phases: wp.array2d(dtype=float), weights: wp.array2d(dtype=float), tet_activations: wp.array2d(dtype=float)
):
# output tile index
i = wp.tid()

# GEMM
p = wp.tile_load(phases, 0, 0, m=PHASE_COUNT, n=1)
w = wp.tile_load(weights, i, 0, m=TILE_TETS, n=PHASE_COUNT)
out = wp.tile_matmul(w, p)

# activation
activations = wp.tile_map(tanh, out)
wp.tile_store(tet_activations, i, 0, activations)


class Example:
def __init__(self, stage_path="example_tile_walker.usd", verbose=False, num_frames=300):
self.verbose = verbose

fps = 60
self.frame_dt = 1.0 / fps
self.num_frames = num_frames

self.sim_substeps = 80
self.sim_dt = self.frame_dt / self.sim_substeps
self.sim_time = 0.0

self.iter = 0
self.train_rate = 0.025

self.phase_count = PHASE_COUNT

self.render_time = 0.0

# bear
asset_stage = Usd.Stage.Open(os.path.join(warp.examples.get_asset_directory(), "bear.usd"))

geom = UsdGeom.Mesh(asset_stage.GetPrimAtPath("/root/bear"))
points = geom.GetPointsAttr().Get()

xform = Gf.Matrix4f(geom.ComputeLocalToWorldTransform(0.0))
for i in range(len(points)):
points[i] = xform.Transform(points[i])

self.points = [wp.vec3(point) for point in points]
self.tet_indices = geom.GetPrim().GetAttribute("tetraIndices").Get()

# sim model
builder = wp.sim.ModelBuilder()
builder.add_soft_mesh(
pos=wp.vec3(0.0, 0.5, 0.0),
rot=wp.quat_identity(),
scale=1.0,
vel=wp.vec3(0.0, 0.0, 0.0),
vertices=self.points,
indices=self.tet_indices,
density=1.0,
k_mu=2000.0,
k_lambda=2000.0,
k_damp=2.0,
tri_ke=0.0,
tri_ka=1e-8,
tri_kd=0.0,
tri_drag=0.0,
tri_lift=0.0,
)

# finalize model
self.model = builder.finalize(requires_grad=True)
self.control = self.model.control()

self.model.soft_contact_ke = 2.0e3
self.model.soft_contact_kd = 0.1
self.model.soft_contact_kf = 10.0
self.model.soft_contact_mu = 0.7

radii = wp.zeros(self.model.particle_count, dtype=float)
radii.fill_(0.05)
self.model.particle_radius = radii
self.model.ground = True

# allocate sim states
self.states = []
for _i in range(self.num_frames * self.sim_substeps + 1):
self.states.append(self.model.state(requires_grad=True))

# initialize the integrator.
self.integrator = wp.sim.SemiImplicitIntegrator()

# model input
self.phases = []
for _i in range(self.num_frames):
self.phases.append(wp.zeros(self.phase_count, dtype=float, requires_grad=True))

# weights matrix for linear network
rng = np.random.default_rng(42)
k = 1.0 / self.phase_count
weights = rng.uniform(-np.sqrt(k), np.sqrt(k), (self.model.tet_count, self.phase_count))
self.weights = wp.array(weights, dtype=float, requires_grad=True)

# tanh activation layer array
self.tet_activations = []
for _i in range(self.num_frames):
self.tet_activations.append(wp.zeros(self.model.tet_count, dtype=float, requires_grad=True))

# optimization
self.loss = wp.zeros(1, dtype=float, requires_grad=True)
self.coms = []
for _i in range(self.num_frames):
self.coms.append(wp.zeros(1, dtype=wp.vec3, requires_grad=True))
self.optimizer = warp.optim.Adam([self.weights.flatten()], lr=self.train_rate)

# rendering
if stage_path:
self.renderer = wp.sim.render.SimRenderer(self.model, stage_path)
else:
self.renderer = None

# capture forward/backward passes
self.use_cuda_graph = wp.get_device().is_cuda
if self.use_cuda_graph:
with wp.ScopedCapture() as capture:
self.tape = wp.Tape()
with self.tape:
for i in range(self.num_frames):
self.forward(i)
self.tape.backward(self.loss)
self.graph = capture.graph

def forward(self, frame):
with wp.ScopedTimer("network", active=self.verbose):
# build sinusoidal input phases
wp.launch(kernel=compute_phases, dim=self.phase_count, inputs=[self.phases[frame], self.sim_time])

# apply linear network with tanh activation
wp.launch_tiled(
kernel=network,
dim=math.ceil(self.model.tet_count / TILE_TETS),
inputs=[self.phases[frame].reshape((self.phase_count, 1)), self.weights],
outputs=[self.tet_activations[frame].reshape((self.model.tet_count, 1))],
block_dim=TILE_THREADS,
)
self.control.tet_activations = self.tet_activations[frame]

with wp.ScopedTimer("simulate", active=self.verbose):
# run simulation loop
for i in range(self.sim_substeps):
self.states[frame * self.sim_substeps + i].clear_forces()
self.integrator.simulate(
self.model,
self.states[frame * self.sim_substeps + i],
self.states[frame * self.sim_substeps + i + 1],
self.sim_dt,
self.control,
)
self.sim_time += self.sim_dt

with wp.ScopedTimer("loss", active=self.verbose):
# compute center of mass velocity
wp.launch(
com_kernel,
dim=self.model.particle_count,
inputs=[
self.states[(frame + 1) * self.sim_substeps].particle_qd,
self.model.particle_count,
self.coms[frame],
],
outputs=[],
)
# compute loss
wp.launch(loss_kernel, dim=1, inputs=[self.coms[frame], self.loss], outputs=[])

def step(self):
with wp.ScopedTimer("step"):
if self.use_cuda_graph:
wp.capture_launch(self.graph)
else:
self.tape = wp.Tape()
with self.tape:
for i in range(self.num_frames):
self.forward(i)
self.tape.backward(self.loss)

# optimization
x = self.weights.grad.flatten()
self.optimizer.step([x])

loss = self.loss.numpy()
if self.verbose:
print(f"Iteration {self.iter}: {loss}")

# reset sim
self.sim_time = 0.0
self.states[0] = self.model.state(requires_grad=True)

# clear grads and zero arrays for next iteration
self.tape.zero()
self.loss.zero_()
for i in range(self.num_frames):
self.coms[i].zero_()

self.iter += 1

def render(self):
if self.renderer is None:
return

with wp.ScopedTimer("render"):
for i in range(self.num_frames + 1):
self.renderer.begin_frame(self.render_time)
self.renderer.render(self.states[i * self.sim_substeps])
self.renderer.end_frame()

self.render_time += self.frame_dt


if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("--device", type=str, default=None, help="Override the default Warp device.")
parser.add_argument(
"--stage_path",
type=lambda x: None if x == "None" else str(x),
default="example_tile_walker.usd",
help="Path to the output USD file.",
)
parser.add_argument("--num_frames", type=int, default=300, help="Total number of frames per training iteration.")
parser.add_argument("--train_iters", type=int, default=30, help="Total number of training iterations.")
parser.add_argument("--verbose", action="store_true", help="Print out additional status messages during execution.")

args = parser.parse_known_args()[0]

with wp.ScopedDevice(args.device):
example = Example(stage_path=args.stage_path, verbose=args.verbose, num_frames=args.num_frames)

for _ in range(args.train_iters):
example.step()
example.render()

if example.renderer:
example.renderer.save()

0 comments on commit e79a945

Please sign in to comment.