refactoring
tianleiwu committed Jan 10, 2025
1 parent 64381ae commit ee01f92
Showing 4 changed files with 73 additions and 78 deletions.
70 changes: 53 additions & 17 deletions onnxruntime/core/providers/cpu/nn/layer_norm_helper.h
@@ -5,6 +5,7 @@

#include "core/framework/tensor_shape.h"
#include "core/common/status.h"
#include "core/common/narrow.h"

namespace onnxruntime {

@@ -14,24 +15,57 @@ constexpr const char* kLayerNormInputShapeMismatchError =

constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";

constexpr int64_t kLayerNormInvalidInput = -1;

struct LayerNormParams {
int64_t num_rows;
int64_t norm_size; // size per row
int64_t scale_size;
int64_t bias_size;
int64_t broadcast_param;
};

// When X shape is (B, S, ...), task_idx is in the range of [0, B * S).
// We support the following scale and bias shapes:
// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
// When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
// When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
// When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
// Below is a macro to compute the initial index for scale and bias data.
#ifndef LAYER_NORM_SCALE_BIAS_OFFSET
#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, row_idx, norm_size) \
((broadcast_param == 0) ? 0 \
: norm_size * (broadcast_param > 0 ? row_idx / broadcast_param : row_idx % (-broadcast_param)))
#endif
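
To make the encoding concrete, here is a minimal standalone sketch (illustrative only, not part of the commit; DemoScaleBiasOffset is a hypothetical helper) that exercises the macro for each documented case, assuming X has shape (B, S, H) = (2, 3, 4), so S = 3, norm_size = 4, and task_idx ranges over [0, 6):

#include <cassert>
#include <cstdint>
#include "core/providers/cpu/nn/layer_norm_helper.h"

void DemoScaleBiasOffset() {
  const int64_t S = 3, norm_size = 4, task_idx = 4;  // row 4 = batch 1, sequence position 1
  // scale/bias shape (1, 1, 4) or (4): broadcast_param = 0, all rows share one copy.
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(int64_t{0}, task_idx, norm_size) == 0);
  // scale/bias shape (2, 1, 4): broadcast_param = S, rows of one batch share a copy: 4 * (4 / 3) = 4.
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(S, task_idx, norm_size) == 4);
  // scale/bias shape (2, 3, 4): broadcast_param = 1, each row has its own copy: 4 * 4 = 16.
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(int64_t{1}, task_idx, norm_size) == 16);
  // scale/bias shape (1, 3, 4): broadcast_param = -S, copies repeat across batches: 4 * (4 % 3) = 4.
  assert(LAYER_NORM_SCALE_BIAS_OFFSET(-S, task_idx, norm_size) == 4);
}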

class LayerNormHelper {
public:
static Status CheckBroadcast(const TensorShape& x_shape,
const TensorShape& scale_shape,
const TensorShape& bias_shape,
bool has_bias,
int64_t axis,
int64_t& broadcast_param) {
broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
if (broadcast_param == 0) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
kLayerNormInputShapeMismatchError,
" X.shape=", x_shape,
" scale.shape=", scale_shape,
" bias.shape=", bias_shape,
" and axis=", axis);
}
static Status CheckInputs(const TensorShape& x_shape,
const TensorShape& scale_shape,
const TensorShape& bias_shape,
bool has_bias,
int64_t axis,
LayerNormParams& params) {
params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
params.scale_size = scale_shape.Size();
params.bias_size = bias_shape.Size();
params.broadcast_param = 0;

if (params.norm_size <= 1) {
params.broadcast_param = kLayerNormInvalidInput;
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size);
} else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) {
params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
if (params.broadcast_param == kLayerNormInvalidInput) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
kLayerNormInputShapeMismatchError,
" X.shape=", x_shape,
" scale.shape=", scale_shape,
" bias.shape=", bias_shape,
" and axis=", axis);
}
}
return Status::OK();
}

@@ -47,7 +81,8 @@ class LayerNormHelper {
(bias_shape == nullptr || *bias_shape == scale_shape)) {
for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
if (x_shape.GetDims()[i] != scale_shape.GetDims()[i]) {
return 0;
// scale cannot be broadcast to X; the input is invalid.
return kLayerNormInvalidInput;
}
}

Expand All @@ -69,7 +104,8 @@ class LayerNormHelper {
}
}

return 0;
// Other cases are not supported.
return kLayerNormInvalidInput;
}
};

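As a usage sketch (hypothetical caller; assumes TensorShape's initializer-list constructor), the refactored CheckInputs validates the shapes and fills every derived size in one call. For example, with X of shape (2, 3, 4), scale and bias of shape (1, 3, 4), and axis = 2:

LayerNormParams params;
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(
    TensorShape({2, 3, 4}), TensorShape({1, 3, 4}), TensorShape({1, 3, 4}),
    /*has_bias*/ true, /*axis*/ 2, params));
// On success: params.num_rows == 6, params.norm_size == 4, params.scale_size == 12,
// and params.broadcast_param == -3 (the (1, S, ...) case documented above).
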
46 changes: 10 additions & 36 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
@@ -57,16 +57,8 @@ void ComputeJob(
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

// When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
// We support scale and bias shape like below:
// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
// When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
// When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
// When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
// Here we compute the initial index for scale and bias data.
int64_t i = (broadcast_param == 0)
? 0
: norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
// Compute the offset of gamma and beta to support broadcasting.
int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);

for (int64_t h = 0; h < norm_size; h++, i++) {
if (simplified) {
@@ -134,16 +126,8 @@ void ComputeJob(
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

// When X shape is (B, S, ...), and task_idx is in the range of [0, B * S).
// We support scale and bias shape like below:
// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
// When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
// When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
// When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
// Here we compute the initial index for scale and bias data.
int64_t i = (broadcast_param == 0)
? 0
: norm_size * (broadcast_param > 0 ? (task_idx / broadcast_param) : (task_idx % (-broadcast_param)));
// Compute the offset of gamma and beta to support broadcasting.
int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);

for (size_t h = 0; h < num_elems; h++, i++) {
if (simplified) {
@@ -283,38 +267,28 @@ Status LayerNormImpl::ComputeWithoutContext(
float epsilon,
bool simplified,
AllocatorPtr alloc) const {
int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));

int64_t scale_size = scale_shape.Size();
int64_t bias_size = bias_shape.Size();
int64_t broadcast_param = 0;

if (norm_size <= 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, norm_size);
} else if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckBroadcast(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, broadcast_param));
}
LayerNormParams params;
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));

IAllocatorUniquePtr<float> scale_fp32;
IAllocatorUniquePtr<float> bias_fp32;
if constexpr (std::is_same_v<T, MLFloat16>) {
if (prepacked_scale_fp32_data_ == nullptr) {
const size_t num_elems = static_cast<size_t>(scale_size);
const size_t num_elems = static_cast<size_t>(params.scale_size);
scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
}
if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
const size_t num_elems = static_cast<size_t>(bias_size);
const size_t num_elems = static_cast<size_t>(params.bias_size);
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
}
}

concurrency::ThreadPool::TryBatchParallelFor(
thread_pool, static_cast<int32_t>(norm_count),
thread_pool, static_cast<int32_t>(params.num_rows),
[&](ptrdiff_t task_idx) {
ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size, broadcast_param,
ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param,
prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
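
For reference, here is a scalar sketch of the per-row math each ComputeJob invocation performs (an illustration under simplifying assumptions — float-only, with NormalizeRow as a hypothetical name; the real code also covers the MLFloat16 path and the simplified RMSNorm variant):

#include <cmath>
#include <cstdint>

// Normalize one row of X, then apply the (possibly broadcast) scale and bias.
// offset is the row's LAYER_NORM_SCALE_BIAS_OFFSET value.
void NormalizeRow(const float* x_row, float* y_row, int64_t norm_size,
                  const float* scale, const float* bias, int64_t offset, float epsilon) {
  float mean = 0.0f, mean_square = 0.0f;
  for (int64_t h = 0; h < norm_size; ++h) {
    mean += x_row[h];
    mean_square += x_row[h] * x_row[h];
  }
  mean /= static_cast<float>(norm_size);
  const float std_dev = std::sqrt(mean_square / static_cast<float>(norm_size) - mean * mean + epsilon);
  for (int64_t h = 0; h < norm_size; ++h) {
    y_row[h] = (x_row[h] - mean) / std_dev * scale[offset + h] +
               (bias != nullptr ? bias[offset + h] : 0.0f);
  }
}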
21 changes: 7 additions & 14 deletions onnxruntime/core/providers/cuda/nn/layer_norm.cc
@@ -48,20 +48,11 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) const
auto x_num_dims = x_shape.NumDimensions();
const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);

int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));

const TensorShape& scale_shape = scale->Shape();

const TensorShape& bias_shape = bias_data ? bias->Shape() : TensorShape();

int64_t broadcast_param = 0;
if (n2 <= 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, n2);
} else if (scale_shape.Size() != n2 || (bias_data && bias_shape.Size() != n2)) {
// Check if scale and bias can be broadcasted to X (only limited cases are supported).
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckBroadcast(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, broadcast_param));
}
LayerNormParams params;
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));

// Outputs
Tensor* Y = ctx->Output(0, x_shape);
@@ -97,9 +88,11 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) const
return Status::OK();
}

HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
X_data, n1, n2, epsilon_, scale_data, bias_data,
gsl::narrow_cast<int>(broadcast_param));
HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(
GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, X_data,
onnxruntime::narrow<int>(params.num_rows), onnxruntime::narrow<int>(params.norm_size), epsilon_,
scale_data, bias_data,
onnxruntime::narrow<int>(params.broadcast_param));
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
14 changes: 3 additions & 11 deletions onnxruntime/core/providers/cuda/nn/layer_norm_impl.cu
@@ -23,8 +23,8 @@
/* Modifications Copyright (c) Microsoft. */

#include "core/providers/cuda/cu_inc/common.cuh"

#include "layer_norm_impl.h"
#include "core/providers/cpu/nn/layer_norm_helper.h"

namespace onnxruntime {
namespace cuda {
@@ -355,16 +355,8 @@ __global__ void cuApplyLayerNorm(
T* skip_input_bias_add_ovals = (skip_input_bias_add_output != nullptr) ? skip_input_bias_add_output + offset : nullptr;
U c_inv_std_dev = rsqrt(sigma2 + epsilon);

// When X shape is (B, S, ...), and i1 is in the range of [0, B * S).
// We support scale and bias shape like below:
// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
// When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
// When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
// When scale and bias shape is (1, S, ...), value of broadcast_param is -S.
// Here we compute the offset of gamma and beta (assuming they have same shape) to support broadcasting.
int gamma_beta_offset = (broadcast_param == 0)
? 0
: n2 * (broadcast_param > 0 ? (i1 / broadcast_param) : (i1 % (-broadcast_param)));
// Compute the offset of gamma and beta to support broadcasting.
int gamma_beta_offset = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, i1, n2);

const int numx = blockDim.x * blockDim.y;
const int thrx = threadIdx.x + threadIdx.y * blockDim.x;
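Downstream of this offset, the element loop consumes gamma and beta in the usual strided pattern; roughly (an illustrative sketch, not the verbatim kernel body, which also handles the optional skip/bias outputs):

for (int i = thrx; i < n2; i += numx) {
  // Normalize the element, then apply the row's (possibly broadcast) gamma and beta.
  ovals[i] = gamma[gamma_beta_offset + i] * c_inv_std_dev * (vals[i] - mu)
           + (beta != nullptr ? beta[gamma_beta_offset + i] : 0);
}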
