From d3cab4273fc3e94da963a59bd652860e35c4569c Mon Sep 17 00:00:00 2001
From: Zach Corse
Date: Wed, 15 Jan 2025 10:46:55 -0800
Subject: [PATCH] Maintain one module hash/exec per block_dim used

---
 CHANGELOG.md    |  1 +
 warp/context.py | 39 +++++++++++++++++++--------------------
 2 files changed, 20 insertions(+), 20 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 74507aec8..e96c875e6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,6 +18,7 @@
   `module_codegen` ([GH-431](https://github.com/NVIDIA/warp/issues/431)).
 - Emit deprecation warnings for the use of the `owner` and `length` keywords
   in the `wp.array` initializer.
+- Avoid recompilation of modules when changing `block_dim`.
 
 ### Fixed
 
diff --git a/warp/context.py b/warp/context.py
index 4d2d215ed..a15c5ada7 100644
--- a/warp/context.py
+++ b/warp/context.py
@@ -1803,13 +1803,13 @@ def __init__(self, name, loader):
         self._live_kernels = weakref.WeakSet()
 
         # executable modules currently loaded
-        self.execs = {}  # (device.context: ModuleExec)
+        self.execs = {}  # ((device.context, blockdim): ModuleExec)
 
         # set of device contexts where the build has failed
         self.failed_builds = set()
 
-        # hash data, including the module hash
-        self.hasher = None
+        # hash data, including the module hash. Module may store multiple hashes (one per block_dim used)
+        self.hashers = {}
 
         # LLVM executable modules are identified using strings. Since it's possible for multiple
         # executable versions to be loaded at the same time, we need a way to ensure uniqueness.
@@ -1967,28 +1967,27 @@ def add_ref(ref):
 
     def hash_module(self):
         # compute latest hash
-        self.hasher = ModuleHasher(self)
-        return self.hasher.get_module_hash()
+        block_dim = self.options["block_dim"]
+        self.hashers[block_dim] = ModuleHasher(self)
+        return self.hashers[block_dim].get_module_hash()
 
     def load(self, device, block_dim=None) -> ModuleExec:
         device = runtime.get_device(device)
 
-        # re-compile module if tile size (blockdim) changes
-        # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
-        # that can return a single kernel instance with a given block size
+        # update module options if launching with a new block dim
         if block_dim is not None:
-            if self.options["block_dim"] != block_dim:
-                self.unload()
             self.options["block_dim"] = block_dim
 
+        active_block_dim = self.options["block_dim"]
+
         # compute the hash if needed
-        if self.hasher is None:
-            self.hasher = ModuleHasher(self)
+        if active_block_dim not in self.hashers:
+            self.hashers[active_block_dim] = ModuleHasher(self)
 
         # check if executable module is already loaded and not stale
-        exec = self.execs.get(device.context)
+        exec = self.execs.get((device.context, active_block_dim))
         if exec is not None:
-            if exec.module_hash == self.hasher.module_hash:
+            if exec.module_hash == self.hashers[active_block_dim].get_module_hash():
                 return exec
 
         # quietly avoid repeated build attempts to reduce error spew
@@ -1996,7 +1995,7 @@
             return None
 
         module_name = "wp_" + self.name
-        module_hash = self.hasher.module_hash
+        module_hash = self.hashers[active_block_dim].get_module_hash()
 
         # use a unique module path using the module short hash
         module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
@@ -2053,7 +2052,7 @@
                     # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
                     "output_arch": output_arch,
                 }
-                builder = ModuleBuilder(self, builder_options, hasher=self.hasher)
+                builder = ModuleBuilder(self, builder_options, hasher=self.hashers[active_block_dim])
 
                 # create a temporary (process unique) dir for build outputs before moving to the binary dir
                 build_dir = os.path.join(
@@ -2205,13 +2204,13 @@ def safe_rename(src, dst, attempts=5, delay=0.1):
                     self.cpu_exec_id += 1
                     runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
                     module_exec = ModuleExec(module_handle, module_hash, device, meta)
-                    self.execs[None] = module_exec
+                    self.execs[(None, active_block_dim)] = module_exec
 
                 elif device.is_cuda:
                     cuda_module = warp.build.load_cuda(binary_path, device)
                     if cuda_module is not None:
                         module_exec = ModuleExec(cuda_module, module_hash, device, meta)
-                        self.execs[device.context] = module_exec
+                        self.execs[(device.context, active_block_dim)] = module_exec
                     else:
                         module_load_timer.extra_msg = " (error)"
                         raise Exception(f"Failed to load CUDA module '{self.name}'")
@@ -2233,14 +2232,14 @@ def unload(self):
 
     def mark_modified(self):
         # clear hash data
-        self.hasher = None
+        self.hashers = {}
 
         # clear build failures
         self.failed_builds = set()
 
     # lookup kernel entry points based on name, called after compilation / module load
     def get_kernel_hooks(self, kernel, device):
-        module_exec = self.execs.get(device.context)
+        module_exec = self.execs.get((device.context, self.options["block_dim"]))
         if module_exec is not None:
             return module_exec.get_kernel_hooks(kernel)
         else:
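
For reference, a minimal sketch of the behavior this patch enables, assuming Warp's public `wp.launch(..., block_dim=N)` entry point; the kernel and array names below are hypothetical. Previously, `Module` kept a single `hasher`/`exec`, so switching between two block sizes called `unload()` and rebuilt the module on every switch. With the `(device.context, block_dim)` keying above, each block size is compiled once and subsequent launches reuse the cached executable.

```python
import warp as wp

wp.init()

@wp.kernel
def scale(a: wp.array(dtype=float), s: float):
    i = wp.tid()
    a[i] = a[i] * s

a = wp.zeros(1024, dtype=float, device="cuda:0")

# First launch at each block size compiles and caches a module exec
# keyed by (device.context, block_dim).
wp.launch(scale, dim=a.shape, inputs=[a, 2.0], block_dim=128)
wp.launch(scale, dim=a.shape, inputs=[a, 0.5], block_dim=256)

# Switching back to a previously used block size is now a cache hit;
# before this patch it forced an unload and a full rebuild.
wp.launch(scale, dim=a.shape, inputs=[a, 2.0], block_dim=128)
```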