diff --git a/paddle/phi/kernels/cpu/numel_kernel.cc b/paddle/phi/kernels/cpu/numel_kernel.cc
index c7d56efc207813..29c8595b0c31a7 100644
--- a/paddle/phi/kernels/cpu/numel_kernel.cc
+++ b/paddle/phi/kernels/cpu/numel_kernel.cc
@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(numel,
                    CPU,
                    ALL_LAYOUT,
                    phi::NumelKernel,
+                   int8_t,
                    uint8_t,
                    int16_t,
                    int,
diff --git a/paddle/phi/kernels/gpu/numel_kernel.cu b/paddle/phi/kernels/gpu/numel_kernel.cu
index 1f760dbf3ad68e..1a8ec3158e8ee5 100644
--- a/paddle/phi/kernels/gpu/numel_kernel.cu
+++ b/paddle/phi/kernels/gpu/numel_kernel.cu
@@ -22,6 +22,8 @@ PD_REGISTER_KERNEL(numel,
                    GPU,
                    ALL_LAYOUT,
                    phi::NumelKernel,
+                   int8_t,
+                   uint8_t,
                    int16_t,
                    int,
                    int64_t,
diff --git a/python/paddle/base/dygraph/tensor_patch_methods.py b/python/paddle/base/dygraph/tensor_patch_methods.py
index 2687353c29ee1f..0bb12a42ede6f4 100644
--- a/python/paddle/base/dygraph/tensor_patch_methods.py
+++ b/python/paddle/base/dygraph/tensor_patch_methods.py
@@ -124,6 +124,7 @@ def _to_static_var(self, to_parameter=False, **kwargs):
             'grad_',
             'strides',
             'offset',
+            '__cuda_array_interface__',
         ]
         param_keys = ['stop_gradient', 'trainable']
         if isinstance(self, EagerParamBase):
@@ -1248,6 +1249,79 @@ def coalesce(self: Tensor, name: str | None = None) -> Tensor:
         """
         return _C_ops.sparse_coalesce(self)

+    @property
+    def __cuda_array_interface__(self):
+        """Array view description for cuda tensors.
+
+        See:
+        CUDA Array Interface (Version 2)
+        https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
+        """
+
+        # raise AttributeError for unsupported tensors, so that
+        # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
+        if "gpu" not in str(self.place):
+            raise AttributeError(
+                "Can't get __cuda_array_interface__ on non-CUDA tensor. "
+                "If CUDA data is required use tensor.cuda() to copy tensor to device memory."
+            )
+
+        if self.is_sparse():
+            raise AttributeError(
+                "Can't get __cuda_array_interface__ on sparse tensor. "
+                "Use Tensor.to_dense() to convert to a dense tensor first."
+            )
+
+        # RuntimeError, matching tensor.__array__() behavior.
+        if not self.stop_gradient:
+            raise RuntimeError(
+                "Can't get __cuda_array_interface__ on Tensor that requires grad. "
+                "If gradients aren't required, use var.detach() to get Tensor that doesn't require grad."
+            )
+
+        # CUDA devices are little-endian and tensors are stored in native byte
+        # order. 1-byte entries are endian-agnostic.
+        typestr = {
+            paddle.complex64: "<c8",
+            paddle.complex128: "<c16",
+            # NOTE: numpy defines no bfloat16 typestr; a 2-byte float code
+            # is assumed here.
+            paddle.bfloat16: "<f2",
+            paddle.float16: "<f2",
+            paddle.float32: "<f4",
+            paddle.float64: "<f8",
+            paddle.uint8: "|u1",
+            paddle.int8: "|i1",
+            paddle.int16: "<i2",
+            paddle.int32: "<i4",
+            paddle.int64: "<i8",
+            paddle.bool: "|b1",
+        }[self.dtype]
+
+        shape = tuple(self.shape)
+
+        # Strides are reported in bytes; contiguous tensors report None,
+        # as expected by version 2 of the interface.
+        if self.is_contiguous():
+            strides = None
+        else:
+            itemsize = self.element_size()
+            strides = tuple(s * itemsize for s in self.strides)
+
+        # A 0-size tensor reports a null data pointer.
+        data_ptr = self.data_ptr() if self.numel() > 0 else 0
+        data = (data_ptr, False)  # read-only is false
+
+        return {
+            "typestr": typestr,
+            "shape": shape,
+            "strides": strides,
+            "data": data,
+            "version": 2,
+        }
+
     if not hasattr(core, "eager"):
         return
@@ -1290,6 +1364,7 @@ def coalesce(self: Tensor, name: str | None = None) -> Tensor:
         ("__hash__", __hash__),
         ("_use_gpudnn", _use_gpudnn),
         ("_md5sum", _md5sum),
+        ("__cuda_array_interface__", __cuda_array_interface__),
     ):
         setattr(core.eager.Tensor, method_name, method)
diff --git a/test/dygraph_to_static/test_tensor_attr_consistency.py b/test/dygraph_to_static/test_tensor_attr_consistency.py
index 226e58d78ee007..60d753361c3222 100644
--- a/test/dygraph_to_static/test_tensor_attr_consistency.py
+++ b/test/dygraph_to_static/test_tensor_attr_consistency.py
@@ -77,6 +77,7 @@
         'tolist',
         'value',
         'zero_',
+        "__cuda_array_interface__",
     ]
 )
 STATIC_ONLY_TENSOR_ATTRS_ALLOW_LIST = OrderedSet(
diff --git a/test/legacy_test/test_eager_tensor.py b/test/legacy_test/test_eager_tensor.py
index 13b989ba8716b0..49658c2a50035a 100644
--- a/test/legacy_test/test_eager_tensor.py
+++ b/test/legacy_test/test_eager_tensor.py
@@ -16,6 +16,7 @@
 import unittest

 import numpy as np
+from utils import dygraph_guard

 import paddle
 import paddle.nn.functional as F
@@ -1187,6 +1188,96 @@ def test_print_tensor_dtype(self):

         self.assertEqual(a_str, expected)

+    def test___cuda_array_interface__(self):
+        """test Tensor.__cuda_array_interface__"""
+        with dygraph_guard():
+            # raise AttributeError for cpu tensor.
+            cpu_place = paddle.CPUPlace()
+            cpu_tensor = paddle.rand([3, 3]).to(device=cpu_place)
+            self.assertRaises(
+                AttributeError,
+                getattr,
+                cpu_tensor,
+                '__cuda_array_interface__',
+            )
+
+            if paddle.device.is_compiled_with_cuda():
+                gpu_place = paddle.CUDAPlace(0)
+                # raise AttributeError for sparse tensor.
+                sparse_tensor = (
+                    paddle.rand([3, 3]).to(device=gpu_place).to_sparse_coo(2)
+                )
+                self.assertRaises(
+                    AttributeError,
+                    getattr,
+                    sparse_tensor,
+                    '__cuda_array_interface__',
+                )
+
+                # strides should be None if contiguous
+                tensor = paddle.randn([3, 3]).to(device=gpu_place)
+                interface = tensor.__cuda_array_interface__
+                assert interface["strides"] is None
+
+                # strides should be tuple of int if not contiguous
+                tensor = paddle.randn([10, 10]).to(device=gpu_place)
+                tensor = tensor[::2]
+                interface = tensor.__cuda_array_interface__
+                assert interface["strides"] == (80, 4)
+
+                # data_ptr should be 0 if tensor is 0-size
+                tensor = paddle.randn([0, 10]).to(device=gpu_place)
+                interface = tensor.__cuda_array_interface__
+                assert interface["data"][0] == 0
+
+                # raise RuntimeError for tensor that requires grad.
+                tensor = paddle.randn([3, 3]).to(device=gpu_place)
+                tensor.stop_gradient = False
+                self.assertRaises(
+                    RuntimeError,
+                    getattr,
+                    tensor,
+                    '__cuda_array_interface__',
+                )
+
+                # check supported dtypes
+                for dtype in [
+                    paddle.complex64,
+                    paddle.complex128,
+                    paddle.bfloat16,
+                    paddle.float16,
+                    paddle.float32,
+                    paddle.float64,
+                    paddle.uint8,
+                    paddle.int8,
+                    paddle.int16,
+                    paddle.int32,
+                    paddle.int64,
+                    paddle.bool,
+                ]:
+                    tensor = (
+                        paddle.uniform([10, 10], min=-10.0, max=10.0)
+                        .to(device=gpu_place)
+                        .astype(dtype)
+                    )
+                    interface = tensor.__cuda_array_interface__
+                    assert "typestr" in interface and isinstance(
+                        interface["typestr"], str
+                    )
+                    assert "shape" in interface and isinstance(
+                        interface["shape"], tuple
+                    )
+                    assert "strides" in interface and (
+                        isinstance(interface["strides"], tuple)
+                        or interface["strides"] is None
+                    )
+                    assert (
+                        "data" in interface
+                        and isinstance(interface["data"], tuple)
+                        and len(interface["data"]) == 2
+                    )
+                    assert "version" in interface and interface["version"] == 2
+

 class TestEagerTensorSetitem(unittest.TestCase):
     def func_setUp(self):
diff --git a/test/legacy_test/test_numel_op.py b/test/legacy_test/test_numel_op.py
index 8bfbed64ef329c..9d7b128d533afe 100644
--- a/test/legacy_test/test_numel_op.py
+++ b/test/legacy_test/test_numel_op.py
@@ -71,6 +71,18 @@ def init(self):
         self.shape = (0,)


+class TestNumelOp1int8(TestNumelOp):
+    def init(self):
+        self.dtype = np.int8
+        self.shape = (11, 66)
+
+
+class TestNumelOp2int8(TestNumelOp):
+    def init(self):
+        self.dtype = np.int8
+        self.shape = (0,)
+
+
 class TestNumelOpComplex(TestNumelOp):
     def setUp(self):
         self.op_type = "size"
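A minimal usage sketch of the __cuda_array_interface__ property added in python/paddle/base/dygraph/tensor_patch_methods.py above. It assumes a CUDA-enabled Paddle build with at least one visible GPU; the expected output in the comment follows the typestr mapping shown in the patch, and the CuPy consumer in the trailing comments is illustrative only and requires cupy to be installed separately.

import paddle

if paddle.device.is_compiled_with_cuda():
    x = (
        paddle.arange(12, dtype="float32")
        .reshape([3, 4])
        .to(device=paddle.CUDAPlace(0))
    )

    # The property returns a plain dict describing the device buffer.
    cai = x.__cuda_array_interface__
    print(cai["typestr"], cai["shape"], cai["strides"], cai["version"])
    # expected: <f4 (3, 4) None 2   (strides is None because x is contiguous)

    # Any consumer of the CUDA Array Interface (CuPy, Numba, ...) can wrap
    # the same device memory without a copy, e.g. with CuPy installed:
    #   import cupy as cp
    #   y = cp.asarray(x)  # zero-copy view over x's GPU buffer
    #   assert y.__cuda_array_interface__["data"][0] == cai["data"][0]

Because "data" reports read_only=False, such a consumer writes directly into the device buffer owned by the Paddle tensor rather than into a copy.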