Skip to content

Commit

Permalink
Merge branch 'NVIDIA:main' into HydrogenSulfate/add_paddle_backend
Browse files Browse the repository at this point in the history
  • Loading branch information
HydrogenSulfate authored Sep 28, 2024
2 parents fc59eb3 + f357794 commit 6b68b38
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 14 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
- Skip unused functions in module code generation, improving performance.
- Avoid reloading modules if their content does not change, improving performance.
- `wp.Mesh.points` is now a property instead of a raw data member, its reference can be changed after the mesh is initialized.
- Improve error message when invalid objects are referenced in a Warp kernel.

### Fixed

Expand All @@ -56,6 +57,8 @@
- Fix a crash when kernel functions are not found in CPU modules.
- Fix conditions not being evaluated as expected in `while` statements.
- Fix printing Boolean and 8-bit integer values.
- Fix array interface type strings used for Boolean and 8-bit integer values.
- Fix initialization error when setting struct members.

## [1.3.3] - 2024-09-04

Expand Down
5 changes: 4 additions & 1 deletion warp/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,9 @@ def __init__(self, cls, key, module):
elif issubclass(var.type, ctypes.Array):
fields.append((label, var.type))
else:
# HACK: fp16 requires conversion functions from warp.so
if var.type is warp.float16:
warp.init()
fields.append((label, var.type._type_))

class StructType(ctypes.Structure):
Expand Down Expand Up @@ -1647,7 +1650,7 @@ def emit_Name(adj, node):
if isinstance(obj, types.ModuleType):
return obj

raise RuntimeError("Cannot reference a global variable from a kernel unless `wp.constant()` is being used")
raise TypeError(f"Invalid external reference type: {type(obj)}")

@staticmethod
def resolve_type_attribute(var_type: type, attr: str):
Expand Down
20 changes: 20 additions & 0 deletions warp/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2590,6 +2590,25 @@ def test_array_from_int64_domain(test, device):
wp.zeros(np.array([1504, 1080, 520], dtype=np.int64), dtype=wp.float32, device=device)


def test_numpy_array_interface(test, device):
# We should be able to convert between NumPy and Warp arrays using __array_interface__ on CPU.
# This tests all scalar types supported by both.

n = 10

scalar_types = wp.types.scalar_types

for dtype in scalar_types:
# test round trip
a1 = wp.zeros(n, dtype=dtype, device="cpu")
na = np.array(a1)
a2 = wp.array(na, device="cpu")

assert a1.dtype == a2.dtype
assert a1.shape == a2.shape
assert a1.strides == a2.strides


devices = get_test_devices()


Expand Down Expand Up @@ -2648,6 +2667,7 @@ def test_array_new_del(self):
add_function_test(TestArray, "test_array_of_structs_roundtrip", test_array_of_structs_roundtrip, devices=devices)
add_function_test(TestArray, "test_array_from_numpy", test_array_from_numpy, devices=devices)
add_function_test(TestArray, "test_array_aliasing_from_numpy", test_array_aliasing_from_numpy, devices=["cpu"])
add_function_test(TestArray, "test_numpy_array_interface", test_numpy_array_interface, devices=["cpu"])

add_function_test(TestArray, "test_array_inplace_ops", test_array_inplace_ops, devices=devices)
add_function_test(TestArray, "test_direct_from_numpy", test_direct_from_numpy, devices=["cpu"])
Expand Down
12 changes: 6 additions & 6 deletions warp/tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,22 +405,22 @@ def kernel_3_fn(

kernel = wp.Kernel(func=kernel_1_fn)
with test.assertRaisesRegex(
RuntimeError,
r"Cannot reference a global variable from a kernel unless `wp.constant\(\)` is being used",
TypeError,
r"Invalid external reference type: <class 'warp.types.array'>",
):
wp.launch(kernel, dim=out.shape, inputs=(), outputs=(out,), device=device)

kernel = wp.Kernel(func=kernel_2_fn)
with test.assertRaisesRegex(
RuntimeError,
r"Cannot reference a global variable from a kernel unless `wp.constant\(\)` is being used",
TypeError,
r"Invalid external reference type: <class 'warp.types.array'>",
):
wp.launch(kernel, dim=out.shape, inputs=(), outputs=(out,), device=device)

kernel = wp.Kernel(func=kernel_3_fn)
with test.assertRaisesRegex(
RuntimeError,
r"Cannot reference a global variable from a kernel unless `wp.constant\(\)` is being used",
TypeError,
r"Invalid external reference type: <class 'warp.types.array'>",
):
wp.launch(kernel, dim=out.shape, inputs=(), outputs=(out,), device=device)

Expand Down
28 changes: 28 additions & 0 deletions warp/tests/test_implicit_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,34 @@ class TestImplicitInitIsPeerAccessSupported(unittest.TestCase):
)


# Structs
# ------------------------------------------------------------------------------


def test_struct_member_init(test, device):
@wp.struct
class S:
# fp16 requires conversion functions from warp.so
x: wp.float16
v: wp.vec3h

s = S()
s.x = 42.0
s.v = wp.vec3h(1.0, 2.0, 3.0)


class TestImplicitInitStructMemberInit(unittest.TestCase):
pass


add_function_test(
TestImplicitInitStructMemberInit,
"test_struct_member_init",
test_struct_member_init,
check_output=False,
)


if __name__ == "__main__":
# Do not clear the kernel cache or call anything that would initialize Warp
# since these tests are specifically aiming to catch issues where Warp isn't
Expand Down
24 changes: 24 additions & 0 deletions warp/tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,27 @@ def wrap_vec_tensor_with_warp_grad(vec_dtype):
wrap_vec_tensor_with_warp_grad(wp.transform)


def test_cuda_array_interface(test, device):
# We should be able to construct Torch tensors from Warp arrays via __cuda_array_interface__ on GPU.
# Note that Torch does not support __array_interface__ on CPU.

torch_device = wp.device_to_torch(device)
n = 10

# test the types supported by both Warp and Torch
scalar_types = [wp.float16, wp.float32, wp.float64, wp.int8, wp.int16, wp.int32, wp.int64, wp.uint8]

for dtype in scalar_types:
# test round trip
a1 = wp.zeros(n, dtype=dtype, device=device)
t = torch.tensor(a1, device=torch_device)
a2 = wp.array(t, device=device)

assert a1.dtype == a2.dtype
assert a1.shape == a2.shape
assert a1.strides == a2.strides


def test_to_torch(test, device):
import torch

Expand Down Expand Up @@ -918,6 +939,9 @@ class TestTorch(unittest.TestCase):
test_warp_graph_torch_stream,
devices=torch_compatible_cuda_devices,
)
add_function_test(
TestTorch, "test_cuda_array_interface", test_cuda_array_interface, devices=torch_compatible_cuda_devices
)

# multi-GPU tests
if len(torch_compatible_cuda_devices) > 1:
Expand Down
2 changes: 1 addition & 1 deletion warp/tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def test_constant(self):
self.assertEqual(const, wp.vec3i(1, 2, 3))

def test_constant_error_invalid_type(self):
with self.assertRaisesRegex(RuntimeError, r"Invalid constant type: <class 'tuple'>$"):
with self.assertRaisesRegex(TypeError, r"Invalid constant type: <class 'tuple'>$"):
wp.constant((1, 2, 3))

def test_vector_assign(self):
Expand Down
12 changes: 6 additions & 6 deletions warp/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,8 @@ def constant(x):
x: Compile-time constant value, can be any of the built-in math types.
"""

if not isinstance(x, (builtins.bool, int, float, tuple(scalar_and_bool_types), ctypes.Array)):
raise RuntimeError(f"Invalid constant type: {type(x)}")
if not is_value(x):
raise TypeError(f"Invalid constant type: {type(x)}")

return x

Expand Down Expand Up @@ -1302,17 +1302,17 @@ def type_to_warp(dtype):

def type_typestr(dtype):
if dtype == bool:
return "?"
return "|b1"
elif dtype == float16:
return "<f2"
elif dtype == float32:
return "<f4"
elif dtype == float64:
return "<f8"
elif dtype == int8:
return "b"
return "|i1"
elif dtype == uint8:
return "B"
return "|u1"
elif dtype == int16:
return "<i2"
elif dtype == uint16:
Expand Down Expand Up @@ -1384,7 +1384,7 @@ def type_is_matrix(t):

# returns true for all value types (int, float, bool, scalars, vectors, matrices)
def type_is_value(x):
return x in value_types or issubclass(x, ctypes.Array)
return x in value_types or hasattr(x, "_wp_scalar_type_")


# equivalent of the above but for values
Expand Down

0 comments on commit 6b68b38

Please sign in to comment.