[API] Add __cuda_array_interface__ for eager Tensor #68192

Merged
22 commits
b7d04aa
Merge pull request #230 from PaddlePaddle/develop
HydrogenSulfate May 10, 2024
2fd9dc0
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate May 10, 2024
4c5afe2
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate May 15, 2024
056d19b
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate May 15, 2024
c022e44
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate May 31, 2024
d723c27
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jun 6, 2024
04664b8
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jun 6, 2024
2f2777c
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jun 19, 2024
36efc60
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jul 4, 2024
6d3d314
Merge pull request #268 from PaddlePaddle/develop
HydrogenSulfate Jul 4, 2024
8eed6d0
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jul 4, 2024
f6815d3
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jul 12, 2024
1b3a43b
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jul 16, 2024
9550534
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jul 22, 2024
0053ffb
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jul 24, 2024
928d668
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Jul 24, 2024
2c3ba4b
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Sep 9, 2024
a993efa
Merge branch 'develop' of https://github.com/HydrogenSulfate/Paddle i…
HydrogenSulfate Sep 10, 2024
8e40e50
add new patch method 'cuda_array_interface' for integrating with numba
HydrogenSulfate Sep 12, 2024
88eddaa
support numelkernel for int8 and update cuda_array_interface to v2 an…
HydrogenSulfate Sep 13, 2024
bf25e63
add __cuda_array_interface__ to attr_not_need_keys
HydrogenSulfate Sep 13, 2024
970f0fe
Merge branch 'PaddlePaddle:develop' into add___cuda_array_interface__
HydrogenSulfate Sep 13, 2024
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/numel_kernel.cc
@@ -22,6 +22,7 @@ PD_REGISTER_KERNEL(numel,
CPU,
ALL_LAYOUT,
phi::NumelKernel,
int8_t,
Member:

Shouldn't the Chinese and English API documentation for numel also reflect this change?

Contributor Author:

I think this can be fixed together with the next PR.

uint8_t,
int16_t,
int,
2 changes: 2 additions & 0 deletions paddle/phi/kernels/gpu/numel_kernel.cu
@@ -22,6 +22,8 @@ PD_REGISTER_KERNEL(numel,
GPU,
ALL_LAYOUT,
phi::NumelKernel,
int8_t,
uint8_t,
int16_t,
int,
int64_t,
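With int8_t and uint8_t now registered for the Numel kernels, Tensor.numel() also works on 8-bit tensors, which __cuda_array_interface__ relies on to detect zero-size tensors. A minimal sketch of what this enables (not part of this diff; constructing the int8 tensor via astype is an assumption made for illustration):

import paddle

# Cast to int8 and count elements; the size/numel kernel previously had no
# int8 registration, so this registration is what makes the call valid for
# 8-bit inputs.
x = paddle.ones([3, 4]).astype(paddle.int8)
assert x.numel().item() == 12
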
75 changes: 75 additions & 0 deletions python/paddle/base/dygraph/tensor_patch_methods.py
@@ -124,6 +124,7 @@ def _to_static_var(self, to_parameter=False, **kwargs):
'grad_',
'strides',
'offset',
'__cuda_array_interface__',
]
param_keys = ['stop_gradient', 'trainable']
if isinstance(self, EagerParamBase):
@@ -1248,6 +1249,79 @@ def coalesce(self: Tensor, name: str | None = None) -> Tensor:
"""
return _C_ops.sparse_coalesce(self)

@property
def __cuda_array_interface__(self):
"""Array view description for cuda tensors.

See:
CUDA Array Interface (Version 2)
https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html
"""

# raise AttributeError for unsupported tensors, so that
# hasattr(cpu_tensor, "__cuda_array_interface__") is False.
if "gpu" not in str(self.place):
raise AttributeError(
"Can't get __cuda_array_interface__ on non-CUDA tensor. "
"If CUDA data is required use tensor.cuda() to copy tensor to device memory."
)

if self.is_sparse():
raise AttributeError(
"Can't get __cuda_array_interface__ on sparse tensor. "
"Use Tensor.to_dense() to convert to a dense tensor first."
)

# RuntimeError, matching tensor.__array__() behavior.
if not self.stop_gradient:
raise RuntimeError(
"Can't get __cuda_array_interface__ on Tensor that requires grad. "
"If gradients aren't required, use var.detach() to get Tensor that doesn't require grad."
)

# CUDA devices are little-endian and tensors are stored in native byte
# order. 1-byte entries are endian-agnostic.
typestr = {
paddle.complex64: "<c8",
paddle.complex128: "<c16",
paddle.bfloat16: "<f2",
paddle.float16: "<f2",
paddle.float32: "<f4",
paddle.float64: "<f8",
paddle.uint8: "|u1",
paddle.int8: "|i1",
paddle.int16: "<i2",
paddle.int32: "<i4",
paddle.int64: "<i8",
paddle.bool: "|b1",
# NOTE: Paddle does not support uint16, uint32, or uint64 yet.
# paddle.uint16: "<u2",
# paddle.uint32: "<u4",
# paddle.uint64: "<u8",
}[self.dtype]

itemsize = self.element_size()

shape = tuple(self.shape)
if self.is_contiguous():
# __cuda_array_interface__ v2 requires the strides to be omitted
# (either not set or set to None) for C-contiguous arrays.
strides = None
else:
# the number of bytes to skip to access the next element at each dimension.
strides = tuple(s * itemsize for s in self.strides)

data_ptr = self.data_ptr() if self.numel().item() > 0 else 0
data = (data_ptr, False) # read-only is false

return {
"typestr": typestr,
"shape": shape,
"strides": strides,
"data": data,
"version": 2,
}

if not hasattr(core, "eager"):
return

@@ -1290,6 +1364,7 @@ def coalesce(self: Tensor, name: str | None = None) -> Tensor:
("__hash__", __hash__),
("_use_gpudnn", _use_gpudnn),
("_md5sum", _md5sum),
("__cuda_array_interface__", __cuda_array_interface__),
):
setattr(core.eager.Tensor, method_name, method)

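For context on how the new property is meant to be consumed: any library that understands the CUDA Array Interface (Numba, CuPy, and similar) can wrap the tensor's device memory without a copy. A minimal consumer-side sketch, assuming a CUDA build of Paddle and Numba installed with a working CUDA driver (not part of this diff):

import paddle
from numba import cuda

# A dense GPU tensor with stop_gradient=True passes the guards in
# __cuda_array_interface__ above (on a CUDA place, not sparse, no grad).
x = paddle.randn([4, 4]).to(device=paddle.CUDAPlace(0))

# Numba recognizes any object exposing __cuda_array_interface__.
assert cuda.is_cuda_array(x)

# Zero-copy DeviceNDArray view over the same device memory.
d_x = cuda.as_cuda_array(x)
print(d_x.shape, d_x.dtype)

Since the data tuple reports the buffer as writable (the read-only flag is False), consumers may also write through such views.
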
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_tensor_attr_consistency.py
@@ -77,6 +77,7 @@
'tolist',
'value',
'zero_',
"__cuda_array_interface__",
]
)
STATIC_ONLY_TENSOR_ATTRS_ALLOW_LIST = OrderedSet(
91 changes: 91 additions & 0 deletions test/legacy_test/test_eager_tensor.py
@@ -16,6 +16,7 @@
import unittest

import numpy as np
from utils import dygraph_guard

import paddle
import paddle.nn.functional as F
@@ -1187,6 +1188,96 @@ def test_print_tensor_dtype(self):

self.assertEqual(a_str, expected)

def test___cuda_array_interface__(self):
"""test Tensor.__cuda_array_interface__"""
with dygraph_guard():
# raise AttributeError for cpu tensor.
cpu_place = paddle.CPUPlace()
cpu_tensor = paddle.rand([3, 3]).to(device=cpu_place)
self.assertRaises(
AttributeError,
getattr,
cpu_tensor,
'__cuda_array_interface__',
)

if paddle.device.is_compiled_with_cuda():
gpu_place = paddle.CUDAPlace(0)
# raise AttributeError for sparse tensor.
sparse_tensor = (
paddle.rand([3, 3]).to(device=gpu_place).to_sparse_coo(2)
)
self.assertRaises(
AttributeError,
getattr,
sparse_tensor,
'__cuda_array_interface__',
)

# strides should be None if contiguous
tensor = paddle.randn([3, 3]).to(device=gpu_place)
interface = tensor.__cuda_array_interface__
assert interface["strides"] is None

# strides should be a tuple of ints if not contiguous
tensor = paddle.randn([10, 10]).to(device=gpu_place)
tensor = tensor[::2]
interface = tensor.__cuda_array_interface__
assert interface["strides"] == (80, 4)

# data_ptr should be 0 if tensor is 0-size
tensor = paddle.randn([0, 10]).to(device=gpu_place)
interface = tensor.__cuda_array_interface__
assert interface["data"][0] == 0

# raise RuntimeError for tensor that requires grad.
tensor = paddle.randn([3, 3]).to(device=gpu_place)
tensor.stop_gradient = False
self.assertRaises(
RuntimeError,
getattr,
tensor,
'__cuda_array_interface__',
)

# check supported dtypes
for dtype in [
paddle.complex64,
paddle.complex128,
paddle.bfloat16,
paddle.float16,
paddle.float32,
paddle.float64,
paddle.uint8,
paddle.int8,
paddle.int16,
paddle.int32,
paddle.int64,
paddle.bool,
]:
tensor = (
paddle.uniform([10, 10], min=-10.0, max=10.0)
.to(device=gpu_place)
.astype(dtype)
)
interface = tensor.__cuda_array_interface__
assert "typestr" in interface and isinstance(
interface["typestr"], str
)
assert "shape" in interface and isinstance(
interface["shape"], tuple
)
assert "strides" in interface and (
isinstance(interface["strides"], tuple)
or interface["strides"] is None
)
assert (
"data" in interface
and isinstance(interface["data"], tuple)
and len(interface["data"]) == 2
)
assert "version" in interface and interface["version"] == 2


class TestEagerTensorSetitem(unittest.TestCase):
def func_setUp(self):
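The (80, 4) strides expected in the non-contiguous test above follow from plain byte-stride arithmetic; a small sketch of the computation (assumes float32 with a 4-byte itemsize; not part of this diff):

# A [10, 10] float32 tensor sliced with [::2] has shape [5, 10] and
# element strides (20, 1): each row step skips every other original row.
itemsize = 4
element_strides = (20, 1)
byte_strides = tuple(s * itemsize for s in element_strides)
assert byte_strides == (80, 4)
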
12 changes: 12 additions & 0 deletions test/legacy_test/test_numel_op.py
@@ -71,6 +71,18 @@ def init(self):
self.shape = (0,)


class TestNumelOp1int8(TestNumelOp):
def init(self):
self.dtype = np.int8
self.shape = (11, 66)


class TestNumelOp2int8(TestNumelOp):
def init(self):
self.dtype = np.int8
self.shape = (0,)


class TestNumelOpComplex(TestNumelOp):
def setUp(self):
self.op_type = "size"