Skip to content

Commit

Permalink
Reuse GELU implementation from PyTorch core
Browse files Browse the repository at this point in the history
Pull Request resolved: #7041

kernels/optimized doesn't need to support embedded systems, so it can just take a header-only dep on PyTorch.

Note that, because we will pick up Sleef internally and ignore it
externally thanks to ATen vec, this PR gets to enable optimized GELU in OSS.

Testing: CI to make sure this doesn't break mobile build modes; happy to take advice on anything not currently covered that might break.
ghstack-source-id: 259204954
@exported-using-ghexport

Differential Revision: [D66335522](https://our.internmc.facebook.com/intern/diff/D66335522/)
  • Loading branch information
swolchok committed Dec 20, 2024
1 parent 0f5423f commit f7cdf1d
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 42 deletions.
2 changes: 2 additions & 0 deletions kernels/optimized/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ message("Generated files ${gen_command_sources}")

list(TRANSFORM _optimized_kernels__srcs PREPEND "${EXECUTORCH_ROOT}/")
add_library(optimized_kernels ${_optimized_kernels__srcs})
find_package(Torch CONFIG REQUIRED)
target_include_directories(optimized_kernels PRIVATE ${TORCH_INCLUDE_DIRS})
target_link_libraries(
optimized_kernels PRIVATE executorch_core cpublas extension_threadpool
)
Expand Down
51 changes: 15 additions & 36 deletions kernels/optimized/cpu/op_gelu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <cmath>

#include <ATen/native/cpu/Gelu.h>
#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/assert.h>

Expand Down Expand Up @@ -46,48 +47,26 @@ void gelu(
CTYPE* out_data = output.mutable_data_ptr<CTYPE>();
size_t lim = input.numel();

// TODO: Add fast path for tanh using sleef's tanh
if (approximate == "tanh") {
// 0.5 * x * (1 + Tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))
for (size_t i = 0; i < lim; ++i) {
const CTYPE x = in_data[i];
const CTYPE kBeta = M_SQRT2 * M_2_SQRTPI * 0.5;
const CTYPE kKappa = 0.044715;
auto x_cube = x * x * x;
auto inner = kBeta * (x + kKappa * x_cube);
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::tanh(inner));
using Vec = at::vec::Vectorized<CTYPE>;
int i = 0;
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
Vec x = Vec::loadu(in_data + i);
at::native::vectorized_gelu_approximated_with_tanh(x).store(out_data + i);
}
} else if (approximate == "none") { // dont appx
// GELU(x) = x * Φ(x) where Φ(x) is the is the Cumulative Distribution
// Function for Gaussian Distribution.

#ifndef __aarch64__
for (size_t i = 0; i < lim; ++i) {
const CTYPE x = in_data[i];
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
for (; i < lim; ++i) {
out_data[i] = at::native::scalar_gelu_approximated_with_tanh(in_data[i]);
}
#else
size_t i = 0;
if (std::is_same<CTYPE, float>::value) {
for (; i + 4 < lim; i += 4) {
const float32x4_t in =
vld1q_f32(static_cast<const float*>(&in_data[i]));
const float32x4_t m_sqrt1_2x4 = {
M_SQRT1_2, M_SQRT1_2, M_SQRT1_2, M_SQRT1_2};
const float32x4_t ones = vmovq_n_f32(1.0);
const float32x4_t halves = vmovq_n_f32(0.5);
float32x4_t out = Sleef_erff4_u10(vmulq_f32(in, m_sqrt1_2x4));
vst1q_f32(
static_cast<float*>(&out_data[i]),
vmulq_f32(vmulq_f32(vaddq_f32(out, ones), in), halves));
}
} else if (approximate == "none") {
using Vec = at::vec::Vectorized<CTYPE>;
int i = 0;
for (; i < lim - (lim % Vec::size()); i += Vec::size()) {
Vec x = Vec::loadu(in_data + i);
at::native::vectorized_gelu(x).store(out_data + i);
}
for (; i < lim; ++i) {
const CTYPE x = in_data[i];
out_data[i] = CTYPE(0.5) * x * (CTYPE(1) + std::erf(x * M_SQRT1_2));
out_data[i] = at::native::scalar_gelu(in_data[i]);
}
#endif // __aarch64__

} else {
ET_KERNEL_CHECK_MSG(
context,
Expand Down
16 changes: 10 additions & 6 deletions kernels/optimized/cpu/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,9 @@ _OPTIMIZED_ATEN_OPS = (
op_target(name = "op_sigmoid"),
op_target(
name = "op_gelu",
deps = select({
"DEFAULT": [],
"ovr_config//cpu:arm64": [
"fbsource//third-party/sleef:sleef_arm",
],
}),
deps = [
":aten_headers_for_executorch",
],
),
op_target(
name = "op_le",
Expand Down Expand Up @@ -94,6 +91,13 @@ _OPTIMIZED_ATEN_OPS = (
),
)


def get_sleef_preprocessor_flags():
if runtime.is_oss:
return []
return ["-DAT_BUILD_ARM_VEC256_WITH_SLEEF"]


def define_common_targets():
"""Defines targets that should be shared between fbcode and xplat.
Expand Down
5 changes: 5 additions & 0 deletions kernels/optimized/optimized-oss.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@
- arg_meta: null
kernel_name: torch::executor::opt_sigmoid_out

- op: gelu.out
kernels:
- arg_meta: null
kernel_name: torch::executor::opt_gelu_out

- op: le.Scalar_out
kernels:
- arg_meta: null
Expand Down

0 comments on commit f7cdf1d

Please sign in to comment.