diff --git a/csrc/layernorm_kernels.cu b/csrc/layernorm_kernels.cu index 06384abfb635b..7f77447e282c3 100644 --- a/csrc/layernorm_kernels.cu +++ b/csrc/layernorm_kernels.cu @@ -59,6 +59,8 @@ struct _halfVec { __device__ inline _halfVec& operator+=(const _halfVec& other) { if constexpr (width % 2 == 0) { + // Unroll the pair loop: each iteration adds two elements as one __half2. + #pragma unroll for (int i = 0; i < width; i += 2) { __half2 z = __half2{data[i], data[i+1]}; z += __half2{other.data[i], other.data[i+1]}; @@ -75,6 +77,8 @@ struct _halfVec { __device__ inline _halfVec& operator*=(const _halfVec& other) { if constexpr (width % 2 == 0) { + // Unroll the pair loop: each iteration multiplies two elements as one __half2. + #pragma unroll for (int i = 0; i < width; i += 2) { __half2 z = __half2{data[i], data[i+1]}; z *= __half2{other.data[i], other.data[i+1]}; @@ -91,6 +95,7 @@ struct _halfVec { __device__ inline _halfVec& operator*=(const float scale) { if constexpr (width % 2 == 0) { + // Unrolled pair loop: each iteration widens one __half2 pair to float2. #pragma unroll for (int i = 0; i < width; i += 2) { float2 zf = __half22float2(__half2{data[i], data[i+1]}); @@ -109,6 +114,7 @@ struct _halfVec { __device__ inline float sum_squares() const { float result = 0.0f; if constexpr (width % 2 == 0) { + // Unrolled pair loop: each iteration widens one __half2 pair to float2. #pragma unroll for (int i = 0; i < width; i += 2) { float2 z = __half22float2(__half2{data[i], data[i+1]});