Merge branch 'ershi/add-module-lineinfo' into 'main'
Add per-module option "lineinfo" to add CUDA-C line information (GH-425, GH-431)

See merge request omniverse/warp!979
mmacklin committed Jan 14, 2025
2 parents 1d8aba6 + 8dbedf8 commit b45954b
Showing 7 changed files with 43 additions and 14 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -9,9 +9,13 @@
- Add optimization example for soft-body properties ([GH-419](https://github.com/NVIDIA/warp/pull/419)).
- Add per-module option to disable fused floating point operations, use `wp.set_module_options({"fuse_fp": False})`
([GH-379](https://github.com/NVIDIA/warp/issues/379)).
- Add per-module option to add CUDA-C line information for profiling, use `wp.set_module_options({"lineinfo": True})`.

### Changed

- Files in the kernel cache will be named according to their directory. Previously, all files began with
`module_codegen` ([GH-431](https://github.com/NVIDIA/warp/issues/431)).

### Fixed

- Fix errors during graph capture caused by module unloading ([GH-401](https://github.com/NVIDIA/warp/issues/401)).
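For readers skimming the diff, a minimal usage sketch of the new option (the kernel, array size, and device name below are hypothetical; only `wp.set_module_options({"lineinfo": True})` comes from this change):

```python
import warp as wp

wp.init()

# Enable CUDA-C line information for kernels defined in this module,
# so tools such as Nsight Compute can correlate source lines with SASS.
wp.set_module_options({"lineinfo": True})


@wp.kernel
def scale(a: wp.array(dtype=float), s: float):
    tid = wp.tid()
    a[tid] = a[tid] * s


a = wp.zeros(1024, dtype=float, device="cuda:0")
wp.launch(scale, dim=1024, inputs=[a, 2.0], device="cuda:0")
wp.synchronize()
```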
7 changes: 7 additions & 0 deletions docs/configuration.rst
@@ -148,6 +148,13 @@ The options for a module can also be queried using ``wp.get_module_options()``.
| | | | that functionally equivalent kernels will produce identical results |
| | | | unaffected by the presence or absence of fused operations. |
+--------------------+---------+-------------+--------------------------------------------------------------------------+
|``lineinfo`` | Boolean | ``False`` | If ``True``, CUDA kernels will be compiled with the |
| | | | ``--generate-line-info`` compiler option, which generates line-number |
| | | | information for device code, e.g. to allow NVIDIA Nsight Compute to |
| | | | correlate CUDA-C source and SASS. Line-number information is always |
| | | | included when compiling kernels in ``"debug"`` mode regardless of this |
| | | | setting. |
+--------------------+---------+-------------+--------------------------------------------------------------------------+
|``cuda_output`` | String | ``None`` | The preferred CUDA output format for kernels. Valid choices are ``None``,|
| | | | ``"ptx"``, and ``"cubin"``. If ``None``, a format will be determined |
| | | | automatically. The module-level setting takes precedence over the global |
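A short sketch of how the documented settings interact (illustrative only; ``wp.get_module_options()`` is the query helper mentioned above, and the ``"mode"`` option is the per-module build mode from the same table):

```python
import warp as wp

# Release build with line info only (what the new option enables).
wp.set_module_options({"lineinfo": True})
assert wp.get_module_options()["lineinfo"] is True

# Per the table, "debug" mode always emits line-number information,
# so "lineinfo" has no additional effect in that mode.
wp.set_module_options({"mode": "debug"})
```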
13 changes: 11 additions & 2 deletions warp/build.py
@@ -14,8 +14,16 @@

# builds cuda source to PTX or CUBIN using NVRTC (output type determined by output_path extension)
def build_cuda(
cu_path, arch, output_path, config="release", verify_fp=False, fast_math=False, fuse_fp=True, ltoirs=None
):
cu_path,
arch,
output_path,
config="release",
verify_fp=False,
fast_math=False,
fuse_fp=True,
lineinfo=False,
ltoirs=None,
) -> None:
with open(cu_path, "rb") as src_file:
src = src_file.read()
cu_path_bytes = cu_path.encode("utf-8")
@@ -45,6 +53,7 @@ def build_cuda(
verify_fp,
fast_math,
fuse_fp,
lineinfo,
output_path,
num_ltoirs,
arr_ltoirs,
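For illustration, a hedged sketch of a direct call to the updated helper (`build_cuda()` is an internal Warp function; the paths and arch below are placeholders, not part of this change):

```python
from warp.build import build_cuda

# Compile generated CUDA-C to PTX with line-number information.
# The output extension (.ptx vs .cubin) selects the output type,
# as noted in the comment above build_cuda().
build_cuda(
    "module.cu",        # cu_path (placeholder)
    75,                 # arch (placeholder, e.g. sm_75)
    "module.sm75.ptx",  # output_path (placeholder)
    config="release",
    fuse_fp=True,
    lineinfo=True,
)
```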
22 changes: 13 additions & 9 deletions warp/context.py
@@ -1823,6 +1823,7 @@ def __init__(self, name, loader):
"enable_backward": warp.config.enable_backward,
"fast_math": False,
"fuse_fp": True,
"lineinfo": False,
"cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
"mode": warp.config.mode,
"block_dim": 256,
@@ -1998,15 +1999,16 @@ def load(self, device, block_dim=None) -> ModuleExec:
module_hash = self.hasher.module_hash

# use a unique module path using the module short hash
module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")
module_name_short = f"{module_name}_{module_hash.hex()[:7]}"
module_dir = os.path.join(warp.config.kernel_cache_dir, module_name_short)

with warp.ScopedTimer(
f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
) as module_load_timer:
# -----------------------------------------------------------
# determine output paths
if device.is_cpu:
output_name = "module_codegen.o"
output_name = f"{module_name_short}.o"
output_arch = None

elif device.is_cuda:
@@ -2026,10 +2028,10 @@ def load(self, device, block_dim=None) -> ModuleExec:

if use_ptx:
output_arch = min(device.arch, warp.config.ptx_target_arch)
output_name = f"module_codegen.sm{output_arch}.ptx"
output_name = f"{module_name_short}.sm{output_arch}.ptx"
else:
output_arch = device.arch
output_name = f"module_codegen.sm{output_arch}.cubin"
output_name = f"{module_name_short}.sm{output_arch}.cubin"

# final object binary path
binary_path = os.path.join(module_dir, output_name)
@@ -2067,7 +2069,7 @@ def load(self, device, block_dim=None) -> ModuleExec:
if device.is_cpu:
# build
try:
source_code_path = os.path.join(build_dir, "module_codegen.cpp")
source_code_path = os.path.join(build_dir, f"{module_name_short}.cpp")

# write cpp sources
cpp_source = builder.codegen("cpu")
@@ -2096,7 +2098,7 @@ def load(self, device, block_dim=None) -> ModuleExec:
elif device.is_cuda:
# build
try:
source_code_path = os.path.join(build_dir, "module_codegen.cu")
source_code_path = os.path.join(build_dir, f"{module_name_short}.cu")

# write cuda sources
cu_source = builder.codegen("cuda")
@@ -2113,9 +2115,10 @@ def load(self, device, block_dim=None) -> ModuleExec:
output_arch,
output_path,
config=self.options["mode"],
verify_fp=warp.config.verify_fp,
fast_math=self.options["fast_math"],
fuse_fp=self.options["fuse_fp"],
verify_fp=warp.config.verify_fp,
lineinfo=self.options["lineinfo"],
ltoirs=builder.ltoirs.values(),
)

@@ -2128,7 +2131,7 @@ def load(self, device, block_dim=None) -> ModuleExec:
# build meta data

meta = builder.build_meta()
meta_path = os.path.join(build_dir, "module_codegen.meta")
meta_path = os.path.join(build_dir, f"{module_name_short}.meta")

with open(meta_path, "w") as meta_file:
json.dump(meta, meta_file)
@@ -2192,7 +2195,7 @@ def safe_rename(src, dst, attempts=5, delay=0.1):
# -----------------------------------------------------------
# Load CPU or CUDA binary

meta_path = os.path.join(module_dir, "module_codegen.meta")
meta_path = os.path.join(module_dir, f"{module_name_short}.meta")
with open(meta_path, "r") as meta_file:
meta = json.load(meta_file)

@@ -3483,6 +3486,7 @@ def __init__(self):
ctypes.c_bool, # verify_fp
ctypes.c_bool, # fast_math
ctypes.c_bool, # fuse_fp
ctypes.c_bool, # lineinfo
ctypes.c_char_p, # output_path
ctypes.c_size_t, # num_ltoirs
ctypes.POINTER(ctypes.c_char_p), # ltoirs
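To make the renaming concrete, a small sketch of the resulting kernel-cache layout (the module name, hash, and cache directory are hypothetical; only the naming pattern comes from this change):

```python
import os

module_name = "wp_example"          # hypothetical module name
module_hash_hex = "b45954b0c0ffee"  # hypothetical module hash
module_name_short = f"{module_name}_{module_hash_hex[:7]}"

kernel_cache_dir = "~/.cache/warp"  # hypothetical cache location
module_dir = os.path.join(kernel_cache_dir, module_name_short)

output_arch = 75
ptx_name = f"{module_name_short}.sm{output_arch}.ptx"
print(os.path.join(module_dir, ptx_name))
# ~/.cache/warp/wp_example_b45954b/wp_example_b45954b.sm75.ptx
```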
2 changes: 1 addition & 1 deletion warp/native/warp.cpp
@@ -1038,7 +1038,7 @@ WP_API bool cuda_graph_end_capture(void* context, void* stream, void** graph_ret
WP_API bool cuda_graph_launch(void* graph, void* stream) { return false; }
WP_API bool cuda_graph_destroy(void* context, void* graph) { return false; }

WP_API size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes) { return 0; }
WP_API size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes) { return 0; }

WP_API void* cuda_load_module(void* context, const char* ptx) { return NULL; }
WP_API void cuda_unload_module(void* context, void* module) {}
7 changes: 6 additions & 1 deletion warp/native/warp.cu
@@ -2654,7 +2654,7 @@ bool write_file(const char* data, size_t size, std::string filename, const char*
}
#endif

size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes)
size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes)
{
// use file extension to determine whether to output PTX or CUBIN
const char* output_ext = strrchr(output_path, '.');
@@ -2715,8 +2715,13 @@ size_t cuda_compile_program(const char* cuda_src, const char* program_name, int
//opts.push_back("--device-debug");
}
else
{
opts.push_back("--define-macro=NDEBUG");

if (lineinfo)
opts.push_back("--generate-line-info");
}

if (verify_fp)
opts.push_back("--define-macro=WP_VERIFY_FP");
else
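A Python paraphrase of the flag selection above, for readability (a sketch of the visible branch only; the real logic and the full option list live in `cuda_compile_program()` in warp.cu):

```python
def nvrtc_line_opts(debug: bool, lineinfo: bool) -> list:
    # Mirrors the branch shown above: release builds define NDEBUG and add
    # --generate-line-info only when requested; debug builds already carry
    # line information, so the flag is not needed there.
    opts = []
    if not debug:
        opts.append("--define-macro=NDEBUG")
        if lineinfo:
            opts.append("--generate-line-info")
    return opts
```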
2 changes: 1 addition & 1 deletion warp/native/warp.h
@@ -320,7 +320,7 @@ extern "C"
WP_API bool cuda_graph_launch(void* graph, void* stream);
WP_API bool cuda_graph_destroy(void* context, void* graph);

WP_API size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes);
WP_API size_t cuda_compile_program(const char* cuda_src, const char* program_name, int arch, const char* include_dir, int num_cuda_include_dirs, const char** cuda_include_dirs, bool debug, bool verbose, bool verify_fp, bool fast_math, bool fuse_fp, bool lineinfo, const char* output_path, size_t num_ltoirs, char** ltoirs, size_t* ltoir_sizes);
WP_API bool cuda_compile_fft(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int size, int elements_per_thread, int direction, int precision, int* shared_memory_size);
WP_API bool cuda_compile_dot(const char* ltoir_output_path, const char* symbol_name, int num_include_dirs, const char** include_dirs, const char* mathdx_include_dir, int arch, int M, int N, int K, int precision_A, int precision_B, int precision_C, int type, int arrangement_A, int arrangement_B, int arrangement_C, int num_threads);

