Skip to content

Commit

Permalink
Merge branch 'block_dim_hash' into 'main'
Browse files: browse the repository at this point in the history.
Maintain one module hash/exec per block_dim used

See merge request omniverse/warp!975
  • Loading branch information
daedalus5 committed Jan 15, 2025
2 parents ca0ee84 + d3cab42 commit 23b00ae
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
`module_codegen` ([GH-431](https://github.com/NVIDIA/warp/issues/431)).
- Emit deprecation warnings for the use of the `owner` and `length` keywords in
the `wp.array` initializer.
- Avoid recompilation of modules when changing `block_dim`.

### Fixed

Expand Down
39 changes: 19 additions & 20 deletions warp/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1803,13 +1803,13 @@ def __init__(self, name, loader):
self._live_kernels = weakref.WeakSet()

# executable modules currently loaded
self.execs = {} # (device.context: ModuleExec)
self.execs = {} # ((device.context, blockdim): ModuleExec)

# set of device contexts where the build has failed
self.failed_builds = set()

# hash data, including the module hash
self.hasher = None
# hash data, including the module hash. Module may store multiple hashes (one per block_dim used)
self.hashers = {}

# LLVM executable modules are identified using strings. Since it's possible for multiple
# executable versions to be loaded at the same time, we need a way to ensure uniqueness.
Expand Down Expand Up @@ -1967,36 +1967,35 @@ def add_ref(ref):

def hash_module(self):
# compute latest hash
self.hasher = ModuleHasher(self)
return self.hasher.get_module_hash()
block_dim = self.options["block_dim"]
self.hashers[block_dim] = ModuleHasher(self)
return self.hashers[block_dim].get_module_hash()

def load(self, device, block_dim=None) -> ModuleExec:
device = runtime.get_device(device)

# re-compile module if tile size (blockdim) changes
# todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
# that can return a single kernel instance with a given block size
# update module options if launching with a new block dim
if block_dim is not None:
if self.options["block_dim"] != block_dim:
self.unload()
self.options["block_dim"] = block_dim

active_block_dim = self.options["block_dim"]

# compute the hash if needed
if self.hasher is None:
self.hasher = ModuleHasher(self)
if active_block_dim not in self.hashers:
self.hashers[active_block_dim] = ModuleHasher(self)

# check if executable module is already loaded and not stale
exec = self.execs.get(device.context)
exec = self.execs.get((device.context, active_block_dim))
if exec is not None:
if exec.module_hash == self.hasher.module_hash:
if exec.module_hash == self.hashers[active_block_dim].get_module_hash():
return exec

# quietly avoid repeated build attempts to reduce error spew
if device.context in self.failed_builds:
return None

module_name = "wp_" + self.name
module_hash = self.hasher.module_hash
module_hash = self.hashers[active_block_dim].get_module_hash()

# use a unique module path using the module short hash
module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
Expand Down Expand Up @@ -2053,7 +2052,7 @@ def load(self, device, block_dim=None) -> ModuleExec:
# Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
"output_arch": output_arch,
}
builder = ModuleBuilder(self, builder_options, hasher=self.hasher)
builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])

# create a temporary (process unique) dir for build outputs before moving to the binary dir
build_dir = os.path.join(
Expand Down Expand Up @@ -2205,13 +2204,13 @@ def safe_rename(src, dst, attempts=5, delay=0.1):
self.cpu_exec_id += 1
runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
module_exec = ModuleExec(module_handle, module_hash, device, meta)
self.execs[None] = module_exec
self.execs[(None, active_block_dim)] = module_exec

elif device.is_cuda:
cuda_module = warp.build.load_cuda(binary_path, device)
if cuda_module is not None:
module_exec = ModuleExec(cuda_module, module_hash, device, meta)
self.execs[device.context] = module_exec
self.execs[(device.context, active_block_dim)] = module_exec
else:
module_load_timer.extra_msg = " (error)"
raise Exception(f"Failed to load CUDA module '{self.name}'")
Expand All @@ -2233,14 +2232,14 @@ def unload(self):

def mark_modified(self):
# clear hash data
self.hasher = None
self.hashers = {}

# clear build failures
self.failed_builds = set()

# lookup kernel entry points based on name, called after compilation / module load
def get_kernel_hooks(self, kernel, device):
module_exec = self.execs.get(device.context)
module_exec = self.execs.get((device.context, self.options["block_dim"]))
if module_exec is not None:
return module_exec.get_kernel_hooks(kernel)
else:
Expand Down

0 comments on commit 23b00ae

Please sign in to comment.