diff --git a/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc b/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc new file mode 100644 index 0000000000000..00d8caf2624a9 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather_elements.cc @@ -0,0 +1,86 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/inlined_containers.h" +#include "core/providers/webgpu/tensor/gather_elements.h" +#include "core/providers/cpu/tensor/utils.h" +#include "core/providers/webgpu/shader_helper.h" +#include "core/providers/webgpu/webgpu_supported_types.h" + +namespace onnxruntime { +namespace webgpu { + +ONNX_OPERATOR_VERSIONED_KERNEL_EX( + GatherElements, + kOnnxDomain, + 11, 12, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + GatherElements); + +ONNX_OPERATOR_KERNEL_EX( + GatherElements, + kOnnxDomain, + 13, + kWebGpuExecutionProvider, + (*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()), + GatherElements); + +Status GatherElementsProgram::GenerateShaderCode(ShaderHelper& shader) const { + const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform); + const ShaderVariableHelper& indices = shader.AddInput("indices", ShaderUsage::UseUniform); + const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform); + + shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size") + << "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n" + << "var idx = " << indices.GetByOffset("global_idx") << ";\n" + << "if (idx < 0) {\n" + << " idx = idx + uniforms.axis_dim_limit;\n" + << "}\n" + << "var input_indices = output_indices;\n" + << input.IndicesSet("input_indices", "uniforms.axis", "u32(idx)") << ";\n" + << "let value = " << input.GetByIndices("input_indices") << ";\n" + << output.SetByOffset("global_idx", "value") << ";\n"; + + return Status::OK(); +} + +Status GatherElements::ComputeInternal(ComputeContext& context) const { + const auto* input_tensor = context.Input(0); + const TensorShape& input_shape = input_tensor->Shape(); + int64_t input_rank = input_shape.NumDimensions(); + + const auto* indices_tensor = context.Input(1); + const TensorShape& indices_shape = indices_tensor->Shape(); + + // Handle negative axis + int64_t axis = axis_; + if (axis < 0) { + axis += input_rank; + } + + auto axis_dim_limit = input_shape[axis]; + + auto output_dims = indices_shape.AsShapeVector(); + TensorShape output_shape(output_dims); + auto* output_tensor = context.Output(0, output_shape); + int64_t output_size = output_tensor->Shape().Size(); + + if (output_size == 0) { + return Status::OK(); + } + + GatherElementsProgram program{}; + program + .AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddInputs({{indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}}) + .AddOutputs({output_tensor}) + .SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE) + .AddUniformVariables({{static_cast(output_size)}, + {static_cast(axis_dim_limit)}, + {static_cast(axis)}}); + return context.RunProgram(program); +} + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/tensor/gather_elements.h b/onnxruntime/core/providers/webgpu/tensor/gather_elements.h new file mode 100644 index 0000000000000..f70bbda84c933 --- /dev/null +++ b/onnxruntime/core/providers/webgpu/tensor/gather_elements.h @@ -0,0 +1,36 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/providers/webgpu/webgpu_kernel.h" +#include "core/providers/webgpu/program.h" + +namespace onnxruntime { +namespace webgpu { + +class GatherElementsProgram final : public Program { + public: + GatherElementsProgram() : Program{"GatherElements"} {} + + Status GenerateShaderCode(ShaderHelper& sh) const override; + + WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32}, + {"axis_dim_limit", ProgramUniformVariableDataType::Int32}, + {"axis", ProgramUniformVariableDataType::Int32}); +}; + +class GatherElements final : public WebGpuKernel { + public: + GatherElements(const OpKernelInfo& info) : WebGpuKernel(info) { + axis_ = info.GetAttrOrDefault("axis", 0); + } + + Status ComputeInternal(ComputeContext& context) const override; + + private: + int64_t axis_; +}; + +} // namespace webgpu +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc index 66209adf6f1a9..295a8de31ed50 100644 --- a/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc +++ b/onnxruntime/core/providers/webgpu/webgpu_execution_provider.cc @@ -649,8 +649,8 @@ std::unique_ptr RegisterKernels() { BuildKernelCreateInfo, BuildKernelCreateInfo, - // BuildKernelCreateInfo, - // BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // BuildKernelCreateInfo, // BuildKernelCreateInfo, diff --git a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc index 5b2d00bb956bf..81e51375b9992 100644 --- a/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/gather_elements_op_test.cc @@ -389,9 +389,10 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) { // skip openvino which will not throw error message but will ensure no out-of-bound access // skip TensorRT because it doesn't support out of bounds indices // skip QNN because it doesn't support out of bounds indices + // skip WebGPU because it doesn't support out of bounds indices test.Run(OpTester::ExpectResult::kExpectFailure, "", {kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider, - kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider}); + kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider}); } TEST(GatherElementsOpTest, BigIndices) {