Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LayerNormalization broadcast (limited support for axis=2) #23297

Merged
merged 11 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
(double)epsilon_, // epsilon
reinterpret_cast<const CudaT*>(gamma->Data<T>()), // gamma
(beta != nullptr) ? reinterpret_cast<const CudaT*>(beta->Data<T>()) : nullptr, // beta
0, // no broadcast for gamma/beta
reinterpret_cast<const CudaT*>(skip->Data<T>()), // skip or residual to add
(bias != nullptr) ? reinterpret_cast<const CudaT*>(bias->Data<T>()) : nullptr, // bias to add
sum_output != nullptr ? reinterpret_cast<CudaT*>(sum_output->MutableData<T>()) : nullptr);
Expand Down
116 changes: 116 additions & 0 deletions onnxruntime/core/providers/cpu/nn/layer_norm_helper.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/framework/tensor_shape.h"
#include "core/common/status.h"
#include "core/common/narrow.h"

namespace onnxruntime {

constexpr const char* kLayerNormInputShapeMismatchError =
"Size of scale and bias (if provided) must match X.shape[axis:], "
"or scale and bias (with same shape) can be broadcasted to X when axis is 2.";

constexpr const char* kLayerNormInvalidSize = "Size of X.shape[axis:] must be larger than 1, got ";

constexpr int64_t kLayerNormInvalidInput = -1;

struct LayerNormParams {
int64_t num_rows;
int64_t norm_size; // size per row
int64_t scale_size;
int64_t bias_size;
int64_t broadcast_param;
};

// We support broadcasting for axis=2, where the first two dimensions are rows, and the rest are columns.
// When X shape is (B, S, ...), and x_row (index of one row in X) is in the range of [0, B * S).
// We support scale and bias shape like below:
// When scale and bias shape is (1, 1, ...) or (...), value of broadcast_param is 0.
// When scale and bias shape is (B, 1, ...), value of broadcast_param is S.
// When scale and bias shape is (B, S, ...), value of broadcast_param is 1.
// When scale and bias shape is (1, S, ...), value of broadcast_param is -S.

// Below is a macro to compute the offset for scale and bias data for a row of X.
#ifndef LAYER_NORM_SCALE_BIAS_OFFSET
#define LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, x_row, norm_size) \
((broadcast_param == 0) ? 0 \
: norm_size * (broadcast_param > 0 ? x_row / broadcast_param : x_row % (-broadcast_param)))
#endif

class LayerNormHelper {
public:
static Status CheckInputs(const TensorShape& x_shape,
const TensorShape& scale_shape,
const TensorShape& bias_shape,
bool has_bias,
int64_t axis,
LayerNormParams& params) {
params.num_rows = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
params.norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));
params.scale_size = scale_shape.Size();
params.bias_size = bias_shape.Size();
params.broadcast_param = 0;

if (params.norm_size <= 1) {
params.broadcast_param = kLayerNormInvalidInput;
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, kLayerNormInvalidSize, params.norm_size);
} else if (params.scale_size != params.norm_size || (has_bias && params.bias_size != params.scale_size)) {
params.broadcast_param = GetBroadcastParam(x_shape, scale_shape, has_bias ? &bias_shape : nullptr, axis);
if (params.broadcast_param == kLayerNormInvalidInput) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
kLayerNormInputShapeMismatchError,
" X.shape=", x_shape,
" scale.shape=", scale_shape,
" bias.shape=", bias_shape,
" and axis=", axis);
}
}
return Status::OK();
}

private:
static int64_t GetBroadcastParam(const TensorShape& x_shape,
const TensorShape& scale_shape,
const TensorShape* bias_shape,
int64_t axis) {
// Note that when size of scale and bias is norm_size, it won't enter this function (see CheckInputs).

// X shape is (B, S, ...)
if (axis == 2 &&
x_shape.NumDimensions() >= 3 &&
x_shape.NumDimensions() == scale_shape.NumDimensions() &&
(bias_shape == nullptr || *bias_shape == scale_shape)) {
for (size_t i = 2; i < x_shape.NumDimensions(); ++i) {
if (x_shape.GetDims()[i] != scale_shape.GetDims()[i]) {
// scale cannot be broadcasted to X. It is invalid input.
return kLayerNormInvalidInput;
}
}

if (x_shape.GetDims()[0] == scale_shape.GetDims()[0]) {
// scale and bias shape is (B, S, ...).
if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) {
return 1;
}

// scale and bias shape is (B, 1, ...), returns S
if (scale_shape.GetDims()[1] == 1) {
return x_shape.GetDims()[1];
}
} else if (scale_shape.GetDims()[0] == 1) {
// scale and bias shape is (1, S, ...), returns -S
if (x_shape.GetDims()[1] == scale_shape.GetDims()[1]) {
return -(x_shape.GetDims()[1]);
}
}
}

// Other cases that are not supported.
return kLayerNormInvalidInput;
}
};

} // namespace onnxruntime
62 changes: 31 additions & 31 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "layer_norm_impl.h"
#include "layer_norm_helper.h"

Check warning on line 5 in onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/providers/cpu/nn/layer_norm_impl.cc:5: Include the directory when naming header files [build/include_subdir] [4]

#include "core/common/safeint.h"
#include "core/framework/tensor.h"
Expand All @@ -24,6 +25,7 @@
const T* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
const int64_t broadcast_param,
const float* scale_float_ptr,
const float* bias_float_ptr,
float epsilon,
Expand Down Expand Up @@ -55,13 +57,16 @@
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

for (int64_t h = 0; h < norm_size; h++) {
// Compute the offset of gamma and beta to support broadcasting.
int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);

for (int64_t h = 0; h < norm_size; h++, i++) {
if (simplified) {
p_output[h] = p_output[h] / mean_square * scale_data[h];
p_output[h] = p_output[h] / mean_square * scale_data[i];
} else if (nullptr == bias_data) {
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h];
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i];
} else {
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[h] + bias_data[h];
p_output[h] = (p_output[h] - mean) / mean_square * scale_data[i] + bias_data[i];
}
}

Expand All @@ -82,6 +87,7 @@
const MLFloat16* bias_data,
const ptrdiff_t task_idx,
const int64_t norm_size,
const int64_t broadcast_param,
const float* scale_float_ptr,
const float* bias_float_ptr,
float epsilon,
Expand Down Expand Up @@ -120,13 +126,16 @@
mean_square = sqrt(mean_square / norm_size - mean * mean + epsilon);
}

for (size_t h = 0; h < num_elems; h++) {
// Compute the offset of gamma and beta to support broadcasting.
int64_t i = LAYER_NORM_SCALE_BIAS_OFFSET(broadcast_param, task_idx, norm_size);

for (size_t h = 0; h < num_elems; h++, i++) {
if (simplified) {
output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[h];
output_float_ptr[h] = output_float_ptr[h] / mean_square * scale_float_ptr[i];
} else if (nullptr == bias_float_ptr) {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h];
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i];
} else {
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[h] + bias_float_ptr[h];
output_float_ptr[h] = (output_float_ptr[h] - mean) / mean_square * scale_float_ptr[i] + bias_float_ptr[i];
}
}

Expand Down Expand Up @@ -161,9 +170,7 @@
simplified_{simplified},
contrib_op_{contrib_op},
prepacked_scale_fp32_data_(nullptr),
prepacked_scale_fp32_size_(0),
prepacked_bias_fp32_data_(nullptr),
prepacked_bias_fp32_size_(0) {
prepacked_bias_fp32_data_(nullptr) {
ORT_ENFORCE(op_kernel_info.GetAttr("axis", &axis_).IsOK());
ORT_ENFORCE(op_kernel_info.GetAttr<float>("epsilon", &epsilon_).IsOK());
}
Expand All @@ -179,8 +186,8 @@
const T* bias_data = (simplified || nullptr == bias) ? nullptr : bias->Data<T>();

const TensorShape& x_shape = X->Shape();
size_t scale_size = scale ? static_cast<size_t>(scale->Shape().Size()) : prepacked_scale_fp32_size_;
size_t bias_size = bias ? static_cast<size_t>(bias->Shape().Size()) : prepacked_bias_fp32_size_;
const TensorShape& scale_shape = scale ? scale->Shape() : prepacked_scale_fp32_shape_;
const TensorShape& bias_shape = bias ? bias->Shape() : prepacked_bias_fp32_shape_;
Tensor* Y = p_ctx->Output(0, x_shape);
T* Y_data = Y->MutableData<T>();

Expand Down Expand Up @@ -215,7 +222,7 @@

AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(p_ctx->GetTempSpaceAllocator(&alloc));
return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_size, bias_data, bias_size, Y_data, mean_data,
return ComputeWithoutContext<T, U>(X_data, x_shape, scale_data, scale_shape, bias_data, bias_shape, Y_data, mean_data,
inv_std_dev_data, thread_pool, axis, epsilon, simplified, alloc);
}

Expand All @@ -234,10 +241,10 @@

is_packed = false;
if (input_idx == 1) { // scale
prepacked_scale_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
prepacked_scale_fp32_shape_ = tensor.Shape();
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_scale_fp32_data_, is_packed);
} else if (input_idx == 2) { // bias
prepacked_bias_fp32_size_ = static_cast<size_t>(tensor.Shape().Size());
prepacked_bias_fp32_shape_ = tensor.Shape();
ConvertMLFloat16ToFloatIfNeeded(tensor, alloc, prepacked_bias_fp32_data_, is_packed);
}

Expand All @@ -249,9 +256,9 @@
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
size_t scale_size,
const TensorShape& scale_shape,
const T* bias_data,
size_t bias_size,
const TensorShape& bias_shape,
T* Y_data,
U* mean_data,
U* inv_std_dev_data,
Expand All @@ -260,35 +267,28 @@
float epsilon,
bool simplified,
AllocatorPtr alloc) const {
int64_t norm_count = x_shape.SizeToDimension(onnxruntime::narrow<size_t>(axis));
int64_t norm_size = x_shape.SizeFromDimension(onnxruntime::narrow<size_t>(axis));

if (static_cast<int64_t>(scale_size) != norm_size || (bias_data && static_cast<int64_t>(bias_size) != norm_size)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", norm_size,
". Size of scale and bias (if provided) must match this. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}
LayerNormParams params;
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));

IAllocatorUniquePtr<float> scale_fp32;
IAllocatorUniquePtr<float> bias_fp32;
if constexpr (std::is_same_v<T, MLFloat16>) {
if (prepacked_scale_fp32_data_ == nullptr) {
const size_t num_elems = static_cast<size_t>(norm_size);
const size_t num_elems = static_cast<size_t>(params.scale_size);
scale_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(scale_data, scale_fp32.get(), num_elems);
}
if (prepacked_bias_fp32_data_ == nullptr && bias_data) {
const size_t num_elems = static_cast<size_t>(norm_size);
const size_t num_elems = static_cast<size_t>(params.bias_size);
bias_fp32 = IAllocator::MakeUniquePtr<float>(alloc, num_elems);
MlasConvertHalfToFloatBuffer(bias_data, bias_fp32.get(), num_elems);
}
}

concurrency::ThreadPool::TryBatchParallelFor(
thread_pool, static_cast<int32_t>(norm_count),
thread_pool, static_cast<int32_t>(params.num_rows),
[&](ptrdiff_t task_idx) {
ComputeJob(X_data, scale_data, bias_data, task_idx, norm_size,
ComputeJob(X_data, scale_data, bias_data, task_idx, params.norm_size, params.broadcast_param,
prepacked_scale_fp32_data_ ? prepacked_scale_fp32_data_.get() : scale_fp32.get(),
prepacked_bias_fp32_data_ ? prepacked_bias_fp32_data_.get() : bias_fp32.get(),
epsilon, simplified, Y_data, mean_data, inv_std_dev_data, alloc);
Expand Down
8 changes: 4 additions & 4 deletions onnxruntime/core/providers/cpu/nn/layer_norm_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ class LayerNormImpl : public OpKernel {
const T* X_data,
const TensorShape& x_shape,
const T* scale_data,
size_t scale_size,
const TensorShape& scale_shape,
const T* bias_data,
size_t bias_size,
const TensorShape& bias_shape,
T* Y_data,
U* mean_data,
U* inv_std_dev,
Expand Down Expand Up @@ -64,9 +64,9 @@ class LayerNormImpl : public OpKernel {
const bool simplified_;
const bool contrib_op_;
IAllocatorUniquePtr<float> prepacked_scale_fp32_data_;
size_t prepacked_scale_fp32_size_;
TensorShape prepacked_scale_fp32_shape_;
IAllocatorUniquePtr<float> prepacked_bias_fp32_data_;
size_t prepacked_bias_fp32_size_;
TensorShape prepacked_bias_fp32_shape_;
};

} // namespace onnxruntime
32 changes: 15 additions & 17 deletions onnxruntime/core/providers/cuda/nn/layer_norm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cuda/nn/layer_norm.h"
#include "core/providers/cuda/nn/layer_norm_impl.h"
#include "core/providers/cpu/nn/layer_norm_helper.h"
#include "core/providers/cuda/cuda_common.h"

namespace onnxruntime {
Expand Down Expand Up @@ -44,28 +45,22 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
auto bias_data = (simplified || (nullptr == bias)) ? nullptr : reinterpret_cast<const CudaV*>(bias->Data<V>());

const TensorShape& x_shape = X->Shape();
const int64_t axis = HandleNegativeAxis(axis_, x_shape.NumDimensions());

int n1 = gsl::narrow<int>(x_shape.SizeToDimension(axis));
int n2 = gsl::narrow<int>(x_shape.SizeFromDimension(axis));

const auto scale_size = scale->Shape().Size();
const auto bias_size = (bias_data) ? bias->Shape().Size() : 0;
if (n2 == 1 || scale_size != n2 || (bias_data && bias_size != n2)) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Size of X.shape()[axis:] == ", n2,
". Size of scale and bias (if provided) must match this "
"and the size must not be 1. Got scale size of ",
scale_size, " and bias size of ", bias_size);
}
auto x_num_dims = x_shape.NumDimensions();
const int64_t axis = HandleNegativeAxis(axis_, x_num_dims);

const TensorShape& scale_shape = scale->Shape();
const TensorShape& bias_shape = bias_data ? bias->Shape() : TensorShape();

LayerNormParams params;
ORT_RETURN_IF_ERROR(LayerNormHelper::CheckInputs(x_shape, scale_shape, bias_shape, bias_data != nullptr, axis, params));

// Outputs
Tensor* Y = ctx->Output(0, x_shape);
auto Y_data = reinterpret_cast<CudaV*>(Y->MutableData<V>());

// Mean and variance
std::vector<int64_t> mean_inv_std_var_dim;
for (int i = 0; i < static_cast<int>(x_shape.NumDimensions()); ++i) {
for (int i = 0; i < static_cast<int>(x_num_dims); ++i) {
if (i < axis) {
mean_inv_std_var_dim.emplace_back(x_shape.GetDims()[i]);
} else {
Expand Down Expand Up @@ -93,8 +88,11 @@ Status LayerNorm<T, U, V, simplified>::ComputeInternal(OpKernelContext* ctx) con
return Status::OK();
}

HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data,
X_data, n1, n2, epsilon_, scale_data, bias_data);
HostApplyLayerNorm<CudaT, CudaU, CudaV, simplified>(
GetDeviceProp(), Stream(ctx), Y_data, mean_data, inv_var_data, X_data,
onnxruntime::narrow<int>(params.num_rows), onnxruntime::narrow<int>(params.norm_size), epsilon_,
scale_data, bias_data,
onnxruntime::narrow<int>(params.broadcast_param));
CUDA_RETURN_IF_ERROR(cudaGetLastError());
return Status::OK();
}
Expand Down
Loading
Loading