Skip to content

Commit

Permalink
Fix handling of bool arguments in generic kernels and functions
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-heiden committed Jun 5, 2024
1 parent 659d00f commit f3780ee
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
27 changes: 27 additions & 0 deletions warp/tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,32 @@ def test_generic_fill_overloads(test, device):
assert_np_equal(a3b.numpy(), np.full((n, 3), True, dtype=np.bool_))


# generic kernel used to test generic types mixed with specialized types
@wp.func
def generic_conditional_setter_func(a: wp.array(dtype=Any), i: int, value: Any, relative: bool):
if relative:
a[i] += value
else:
a[i] = value


@wp.kernel
def generic_conditional_setter(a: wp.array(dtype=Any), i: int, value: Any, relative: bool):
generic_conditional_setter_func(a, i, value, relative)


def test_generic_conditional_setter(test, device):
with wp.ScopedDevice(device):
n = 10
ai = wp.zeros(n, dtype=int)

wp.launch(generic_conditional_setter, dim=1, inputs=[ai, 1, 42, False])
wp.launch(generic_conditional_setter, dim=1, inputs=[ai, 1, 5, True])
ai_true = np.zeros(n, dtype=np.int32)
ai_true[1] = 47
assert_np_equal(ai.numpy(), ai_true)


# custom vector/matrix types
my_vec5 = wp.vec(length=5, dtype=wp.float32)
my_mat55 = wp.mat(shape=(5, 5), dtype=wp.float32)
Expand Down Expand Up @@ -509,6 +535,7 @@ class TestGenerics(unittest.TestCase):
add_function_test(TestGenerics, "test_generic_accumulator_kernel", test_generic_accumulator_kernel, devices=devices)
add_function_test(TestGenerics, "test_generic_fill", test_generic_fill, devices=devices)
add_function_test(TestGenerics, "test_generic_fill_overloads", test_generic_fill_overloads, devices=devices)
add_function_test(TestGenerics, "test_generic_conditional_setter", test_generic_conditional_setter, devices=devices)
add_function_test(TestGenerics, "test_generic_transform_kernel", test_generic_transform_kernel, devices=devices)
add_function_test(
TestGenerics, "test_generic_transform_array_kernel", test_generic_transform_array_kernel, devices=devices
Expand Down
4 changes: 2 additions & 2 deletions warp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ class bool:
def __init__(self, x=False):
self.value = x

def __bool__(self) -> bool:
def __bool__(self) -> builtins.bool:
return self.value != 0

def __float__(self) -> float:
Expand Down Expand Up @@ -4799,7 +4799,7 @@ def infer_argument_types(args, template_types, arg_names=None):
arg_types.append(arg_type(dtype=arg.dtype, ndim=arg.ndim))
elif arg_type in warp.types.scalar_and_bool_types:
arg_types.append(arg_type)
elif arg_type in (int, float):
elif arg_type in (int, float, builtins.bool):
# canonicalize type
arg_types.append(warp.types.type_to_warp(arg_type))
elif hasattr(arg_type, "_wp_scalar_type_"):
Expand Down

0 comments on commit f3780ee

Please sign in to comment.