Create is_remote_tensor helper
Signed-off-by: Bogdan Pereanu <[email protected]>
pereanub committed Jan 9, 2025
1 parent e83292e commit 98d4555
Showing 2 changed files with 9 additions and 11 deletions.
@@ -48,4 +48,8 @@ class ZeroRemoteTensor final : public RemoteTensor {
     bool _external_memory_support = false;
 };
 
+inline bool is_remote_tensor(const std::shared_ptr<ov::ITensor>& tensor) {
+    return std::dynamic_pointer_cast<ZeroRemoteTensor>(tensor) != nullptr;
+}
+
 }  // namespace intel_npu
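For readers skimming the diff: the new helper simply names the dynamic_pointer_cast-and-compare check that was previously repeated at each call site in zero_infer_request.cpp. Below is a minimal standalone sketch of the same pattern; it uses stand-in types for illustration only, not the plugin's real ov::ITensor and ZeroRemoteTensor classes.

// Illustration only: stand-ins for ov::ITensor / ZeroRemoteTensor, not the plugin classes.
#include <iostream>
#include <memory>

struct ITensor {
    virtual ~ITensor() = default;  // polymorphic base so dynamic_pointer_cast works
};
struct ZeroRemoteTensor final : ITensor {};

// Same shape as the helper added above: true when the tensor is actually a ZeroRemoteTensor.
inline bool is_remote_tensor(const std::shared_ptr<ITensor>& tensor) {
    return std::dynamic_pointer_cast<ZeroRemoteTensor>(tensor) != nullptr;
}

int main() {
    std::shared_ptr<ITensor> host = std::make_shared<ITensor>();
    std::shared_ptr<ITensor> remote = std::make_shared<ZeroRemoteTensor>();
    std::cout << std::boolalpha << is_remote_tensor(host) << '\n';    // prints false
    std::cout << std::boolalpha << is_remote_tensor(remote) << '\n';  // prints true
    return 0;
}

With the predicate in place, call sites can read if (!is_remote_tensor(tensor)) instead of casting and comparing against nullptr, which is what the zero_infer_request.cpp hunks below do.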
16 changes: 5 additions & 11 deletions src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp
@@ -467,8 +467,7 @@ void ZeroInferRequest::infer_async() {
 auto zeroTensor = std::dynamic_pointer_cast<ZeroTensor>(levelZeroTensor.at(SINGLE_TENSOR));
 
 if (is_batched_input(ioIndex) || inputDescriptor.isShapeTensor || inputDescriptor.isStateInput ||
-    std::dynamic_pointer_cast<ZeroRemoteTensor>(levelZeroTensor.at(SINGLE_TENSOR)) != nullptr ||
-    zeroTensor == nullptr) {
+    is_remote_tensor(levelZeroTensor.at(SINGLE_TENSOR)) || zeroTensor == nullptr) {
     ++ioIndex;
     continue;
 }
@@ -494,8 +493,7 @@ void ZeroInferRequest::infer_async() {
 auto zeroTensor = std::dynamic_pointer_cast<ZeroTensor>(levelZeroTensor);
 
 if (outputDescriptor.isShapeTensor || outputDescriptor.isStateOutput ||
-    std::dynamic_pointer_cast<ZeroRemoteTensor>(levelZeroTensor) != nullptr ||
-    zeroTensor == nullptr) {
+    is_remote_tensor(levelZeroTensor) || zeroTensor == nullptr) {
     ++ioIndex;
     continue;
 }
@@ -535,9 +533,7 @@ void ZeroInferRequest::infer_async() {
 if (is_batched_input(inputIndex)) {
     if (_graph->get_batch_size().has_value()) {
         for (size_t i = 0; i < userTensor.size(); i++) {
-            auto levelZeroBatchRemoteTensor =
-                std::dynamic_pointer_cast<ZeroRemoteTensor>(get_level_zero_input(inputIndex, i));
-            if (levelZeroBatchRemoteTensor == nullptr) {
+            if (!is_remote_tensor(get_level_zero_input(inputIndex, i))) {
                 void* levelZeroBuffer = get_level_zero_input(inputIndex, i)->data();
 
                 auto userBatchRemoteTensor = std::dynamic_pointer_cast<ZeroRemoteTensor>(userTensor.at(i)._ptr);
@@ -587,8 +583,7 @@ void ZeroInferRequest::infer_async() {
     : extract_object(userRemoteTensor->get_properties(), ov::intel_npu::mem_handle);
 
 const auto& levelZeroTensor = get_level_zero_input(inputIndex);
-auto levelZeroRemoteTensor = std::dynamic_pointer_cast<ZeroRemoteTensor>(levelZeroTensor);
-if (levelZeroRemoteTensor == nullptr) {
+if (!is_remote_tensor(levelZeroTensor)) {
     void* levelZeroBuffer = levelZeroTensor->data();
 
     if (userBuffer != levelZeroBuffer) {
@@ -639,8 +634,7 @@ void ZeroInferRequest::get_result() {
     : extract_object(userRemoteTensor->get_properties(), ov::intel_npu::mem_handle);
 
 const std::shared_ptr<ov::ITensor>& levelZeroTensor = _levelZeroOutputTensors.at(outputIndex);
-auto levelZeroRemoteTensor = std::dynamic_pointer_cast<ZeroRemoteTensor>(levelZeroTensor);
-if (levelZeroRemoteTensor == nullptr) {
+if (!is_remote_tensor(levelZeroTensor)) {
     void* levelZeroBuffer = levelZeroTensor->data();
 
     if (userBuffer != levelZeroBuffer) {