From 429c28139f5d593e4b5c75f2dc976e753f49dde9 Mon Sep 17 00:00:00 2001 From: Eric Shi Date: Thu, 19 Dec 2024 16:33:21 -0800 Subject: [PATCH] Fix various issues with Tile docs --- docs/modules/functions.rst | 47 +++++++++++++++++++------------------ warp/builtins.py | 48 ++++++++++++++++++++------------------ warp/native/tile.h | 7 ++---- warp/stubs.py | 46 ++++++++++++++++++------------------ 4 files changed, 74 insertions(+), 74 deletions(-) diff --git a/docs/modules/functions.rst b/docs/modules/functions.rst index 32ed34bc2..a27c83e50 100644 --- a/docs/modules/functions.rst +++ b/docs/modules/functions.rst @@ -876,7 +876,7 @@ Tile Primitives :returns: A tile with ``shape=(m,n)`` and dtype the same as the source array -.. py:function:: tile_store(a: Array[Any], i: int32, t: Any) -> None +.. py:function:: tile_store(a: Array[Any], i: int32, t: Tile) -> None Stores a 1D tile to a global memory array. @@ -887,7 +887,7 @@ Tile Primitives :param t: The source tile to store data from, must have the same dtype as the destination array -.. py:function:: tile_store(a: Array[Any], i: int32, j: int32, t: Any) -> None +.. py:function:: tile_store(a: Array[Any], i: int32, j: int32, t: Tile) -> None :noindex: :nocontentsentry: @@ -901,7 +901,7 @@ Tile Primitives :param t: The source tile to store data from, must have the same dtype as the destination array -.. py:function:: tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Any) -> Tile +.. py:function:: tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Tile) -> Tile Atomically add a tile to the array `a`, each element will be updated atomically. @@ -967,7 +967,7 @@ Tile Primitives -.. py:function:: untile(a: Any) -> Scalar +.. py:function:: untile(a: Tile) -> Scalar Convert a Tile back to per-thread values. @@ -991,7 +991,7 @@ Tile Primitives t = wp.tile(i)*2 # convert back to per-thread values - s = wp.untile() + s = wp.untile(t) print(s) @@ -1038,7 +1038,7 @@ Tile Primitives Broadcast a tile. - This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules. + This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules. :param a: Tile to broadcast :returns: Tile with broadcast ``shape=(m, n)`` @@ -1061,9 +1061,9 @@ Tile Primitives t = wp.tile_ones(dtype=float, m=16, n=16) s = wp.tile_sum(t) - print(t) + print(s) - wp.launch(compute, dim=[64], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: @@ -1088,18 +1088,19 @@ Tile Primitives @wp.kernel def compute(): - t = wp.tile_arange(start=--10, stop=10, dtype=float) + t = wp.tile_arange(64, 128) s = wp.tile_min(t) - print(t) + print(s) - wp.launch(compute, dim=[64], inputs=[]) + + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: .. code-block:: text - tile(m=1, n=1, storage=register) = [[-10]] + tile(m=1, n=1, storage=register) = [[64 ]] @@ -1118,23 +1119,23 @@ Tile Primitives @wp.kernel def compute(): - t = wp.tile_arange(start=--10, stop=10, dtype=float) - s = wp.tile_min(t) + t = wp.tile_arange(64, 128) + s = wp.tile_max(t) - print(t) + print(s) - wp.launch(compute, dim=[64], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: .. code-block:: text - tile(m=1, n=1, storage=register) = [[10]] + tile(m=1, n=1, storage=register) = [[127 ]] -.. py:function:: tile_reduce(op: Callable, a: Any) -> Tile +.. py:function:: tile_reduce(op: Callable, a: Tile) -> Tile Apply a custom reduction operator across the tile. @@ -1156,7 +1157,7 @@ Tile Primitives print(s) - wp.launch(factorial, dim=[16], inputs=[], block_dim=16) + wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16) Prints: @@ -1166,7 +1167,7 @@ Tile Primitives -.. py:function:: tile_map(op: Callable, a: Any) -> Tile +.. py:function:: tile_map(op: Callable, a: Tile) -> Tile Apply a unary function onto the tile. @@ -1188,7 +1189,7 @@ Tile Primitives print(s) - wp.launch(compute, dim=[16], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16) Prints: @@ -1198,7 +1199,7 @@ Tile Primitives -.. py:function:: tile_map(op: Callable, a: Any, b: Any) -> Tile +.. py:function:: tile_map(op: Callable, a: Tile, b: Tile) -> Tile :noindex: :nocontentsentry: @@ -1226,7 +1227,7 @@ Tile Primitives print(s) - wp.launch(compute, dim=[16], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16) Prints: diff --git a/warp/builtins.py b/warp/builtins.py index d3a89dbe2..73a9261a9 100644 --- a/warp/builtins.py +++ b/warp/builtins.py @@ -1852,6 +1852,7 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st step = args[2] if start is None or stop is None or step is None: + print(args) raise RuntimeError("wp.tile_arange() arguments must be compile time constants") if "dtype" in arg_values: @@ -2083,7 +2084,7 @@ def tile_store_1d_value_func(arg_types, arg_values): add_builtin( "tile_store", - input_types={"a": array(dtype=Any), "i": int, "t": Any}, + input_types={"a": array(dtype=Any), "i": int, "t": Tile(dtype=Any, M=Any, N=Any)}, value_func=tile_store_1d_value_func, variadic=False, skip_replay=True, @@ -2132,7 +2133,7 @@ def tile_store_2d_value_func(arg_types, arg_values): add_builtin( "tile_store", - input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Any}, + input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Tile(dtype=Any, M=Any, N=Any)}, value_func=tile_store_2d_value_func, variadic=False, skip_replay=True, @@ -2177,7 +2178,7 @@ def tile_atomic_add_value_func(arg_types, arg_values): add_builtin( "tile_atomic_add", - input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Any}, + input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Tile(dtype=Any, M=Any, N=Any)}, value_func=tile_atomic_add_value_func, variadic=True, skip_replay=True, @@ -2365,7 +2366,7 @@ def untile_value_func(arg_types, arg_values): add_builtin( "untile", - input_types={"a": Any}, + input_types={"a": Tile(dtype=Any, M=Any, N=Any)}, value_func=untile_value_func, variadic=True, doc="""Convert a Tile back to per-thread values. @@ -2390,7 +2391,7 @@ def compute(): t = wp.tile(i)*2 # convert back to per-thread values - s = wp.untile() + s = wp.untile(t) print(s) @@ -2562,7 +2563,7 @@ def tile_broadcast_dispatch_func(arg_types: Mapping[str, type], return_type: Any variadic=True, doc="""Broadcast a tile. - This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules. + This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules. :param a: Tile to broadcast :returns: Tile with broadcast ``shape=(m, n)``""", @@ -2654,9 +2655,9 @@ def compute(): t = wp.tile_ones(dtype=float, m=16, n=16) s = wp.tile_sum(t) - print(t) + print(s) - wp.launch(compute, dim=[64], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: @@ -2703,18 +2704,19 @@ def tile_min_value_func(arg_types, arg_values): @wp.kernel def compute(): - t = wp.tile_arange(start=--10, stop=10, dtype=float) + t = wp.tile_arange(64, 128) s = wp.tile_min(t) - print(t) + print(s) + - wp.launch(compute, dim=[64], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: .. code-block:: text - tile(m=1, n=1, storage=register) = [[-10]] + tile(m=1, n=1, storage=register) = [[64 ]] """, group="Tile Primitives", @@ -2755,18 +2757,18 @@ def tile_max_value_func(arg_types, arg_values): @wp.kernel def compute(): - t = wp.tile_arange(start=--10, stop=10, dtype=float) - s = wp.tile_min(t) + t = wp.tile_arange(64, 128) + s = wp.tile_max(t) - print(t) + print(s) - wp.launch(compute, dim=[64], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: .. code-block:: text - tile(m=1, n=1, storage=register) = [[10]] + tile(m=1, n=1, storage=register) = [[127 ]] """, group="Tile Primitives", @@ -2796,7 +2798,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any, add_builtin( "tile_reduce", - input_types={"op": Callable, "a": Any}, + input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)}, value_func=tile_reduce_value_func, native_func="tile_reduce", doc="""Apply a custom reduction operator across the tile. @@ -2819,7 +2821,7 @@ def factorial(): print(s) - wp.launch(factorial, dim=[16], inputs=[], block_dim=16) + wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16) Prints: @@ -2856,7 +2858,7 @@ def tile_unary_map_value_func(arg_types, arg_values): add_builtin( "tile_map", - input_types={"op": Callable, "a": Any}, + input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)}, value_func=tile_unary_map_value_func, # dispatch_func=tile_map_dispatch_func, # variadic=True, @@ -2881,7 +2883,7 @@ def compute(): print(s) - wp.launch(compute, dim=[16], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16) Prints: @@ -2923,7 +2925,7 @@ def tile_binary_map_value_func(arg_types, arg_values): add_builtin( "tile_map", - input_types={"op": Callable, "a": Any, "b": Any}, + input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)}, value_func=tile_binary_map_value_func, # dispatch_func=tile_map_dispatch_func, # variadic=True, @@ -2952,7 +2954,7 @@ def compute(): print(s) - wp.launch(compute, dim=[16], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16) Prints: diff --git a/warp/native/tile.h b/warp/native/tile.h index edbc8d665..185f04c28 100644 --- a/warp/native/tile.h +++ b/warp/native/tile.h @@ -1125,8 +1125,6 @@ inline CUDA_CALLABLE auto untile(Tile& tile) } } - - template inline CUDA_CALLABLE void adj_untile(Tile& tile, Tile& adj_tile, Value& adj_ret) { @@ -1156,7 +1154,7 @@ inline CUDA_CALLABLE auto tile_zeros() return T(0); } -// zero initialized tile +// one-initialized tile template inline CUDA_CALLABLE auto tile_ones() { @@ -1164,7 +1162,7 @@ inline CUDA_CALLABLE auto tile_ones() return T(1); } -// zero initialized tile +// tile with evenly spaced values template inline CUDA_CALLABLE auto tile_arange(T start, T stop, T step) { @@ -1220,7 +1218,6 @@ inline CUDA_CALLABLE void tile_store(array_t& dest, int x, int y, Tile& src) src.copy_to_global(dest, x, y); } -// entry point for store template inline CUDA_CALLABLE auto tile_atomic_add(array_t& dest, int x, int y, Tile& src) { diff --git a/warp/stubs.py b/warp/stubs.py index a88eb4444..21ac7a871 100644 --- a/warp/stubs.py +++ b/warp/stubs.py @@ -975,7 +975,7 @@ def tile_load(a: Array[Any], i: int32, j: int32, m: int32, n: int32, storage: st @over -def tile_store(a: Array[Any], i: int32, t: Any): +def tile_store(a: Array[Any], i: int32, t: Tile): """Stores a 1D tile to a global memory array. This method will cooperatively store a tile to global memory using all threads in the block. @@ -988,7 +988,7 @@ def tile_store(a: Array[Any], i: int32, t: Any): @over -def tile_store(a: Array[Any], i: int32, j: int32, t: Any): +def tile_store(a: Array[Any], i: int32, j: int32, t: Tile): """Stores a tile to a global memory array. This method will cooperatively store a tile to global memory using all threads in the block. @@ -1002,7 +1002,7 @@ def tile_store(a: Array[Any], i: int32, j: int32, t: Any): @over -def tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Any) -> Tile: +def tile_atomic_add(a: Array[Any], x: int32, y: int32, t: Tile) -> Tile: """Atomically add a tile to the array `a`, each element will be updated atomically. :param a: Array in global memory, should have the same ``dtype`` as the input tile @@ -1077,7 +1077,7 @@ def compute(): @over -def untile(a: Any) -> Scalar: +def untile(a: Tile) -> Scalar: """Convert a Tile back to per-thread values. This function converts a block-wide tile back to per-thread values. @@ -1100,7 +1100,7 @@ def compute(): t = wp.tile(i) * 2 # convert back to per-thread values - s = wp.untile() + s = wp.untile(t) print(s) @@ -1154,7 +1154,7 @@ def tile_transpose(a: Tile) -> Tile: def tile_broadcast(a: Tile, m: int32, n: int32) -> Tile: """Broadcast a tile. - This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules. + This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules. :param a: Tile to broadcast :returns: Tile with broadcast ``shape=(m, n)`` @@ -1178,10 +1178,10 @@ def compute(): t = wp.tile_ones(dtype=float, m=16, n=16) s = wp.tile_sum(t) - print(t) + print(s) - wp.launch(compute, dim=[64], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: @@ -1207,19 +1207,19 @@ def tile_min(a: Tile) -> Tile: @wp.kernel def compute(): - t = wp.tile_arange(start=--10, stop=10, dtype=float) + t = wp.tile_arange(64, 128) s = wp.tile_min(t) - print(t) + print(s) - wp.launch(compute, dim=[64], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: .. code-block:: text - tile(m=1, n=1, storage=register) = [[-10]] + tile(m=1, n=1, storage=register) = [[64 ]] """ @@ -1239,19 +1239,19 @@ def tile_max(a: Tile) -> Tile: @wp.kernel def compute(): - t = wp.tile_arange(start=--10, stop=10, dtype=float) - s = wp.tile_min(t) + t = wp.tile_arange(64, 128) + s = wp.tile_max(t) - print(t) + print(s) - wp.launch(compute, dim=[64], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64) Prints: .. code-block:: text - tile(m=1, n=1, storage=register) = [[10]] + tile(m=1, n=1, storage=register) = [[127 ]] """ @@ -1259,7 +1259,7 @@ def compute(): @over -def tile_reduce(op: Callable, a: Any) -> Tile: +def tile_reduce(op: Callable, a: Tile) -> Tile: """Apply a custom reduction operator across the tile. This function cooperatively performs a reduction using the provided operator across the tile. @@ -1280,7 +1280,7 @@ def factorial(): print(s) - wp.launch(factorial, dim=[16], inputs=[], block_dim=16) + wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16) Prints: @@ -1293,7 +1293,7 @@ def factorial(): @over -def tile_map(op: Callable, a: Any) -> Tile: +def tile_map(op: Callable, a: Tile) -> Tile: """Apply a unary function onto the tile. This function cooperatively applies a unary function to each element of the tile using all threads in the block. @@ -1314,7 +1314,7 @@ def compute(): print(s) - wp.launch(compute, dim=[16], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16) Prints: @@ -1327,7 +1327,7 @@ def compute(): @over -def tile_map(op: Callable, a: Any, b: Any) -> Tile: +def tile_map(op: Callable, a: Tile, b: Tile) -> Tile: """Apply a binary function onto the tile. This function cooperatively applies a binary function to each element of the tiles using all threads in the block. @@ -1352,7 +1352,7 @@ def compute(): print(s) - wp.launch(compute, dim=[16], inputs=[]) + wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16) Prints: