Skip to content

Commit

Permalink
Add a few missing unroll directives
Browse files Browse the repository at this point in the history
  • Loading branch information
mawong-amd committed Mar 25, 2024
1 parent 20f8bd1 commit a7164ca
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions csrc/layernorm_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ struct _halfVec {

__device__ inline _halfVec& operator+=(const _halfVec<width>& other) {
if constexpr (width % 2 == 0) {
// Hint to compiler to use packed operations
#pragma unroll
for (int i = 0; i < width; i += 2) {
__half2 z = __half2{data[i], data[i+1]};
z += __half2{other.data[i], other.data[i+1]};
Expand All @@ -75,6 +77,8 @@ struct _halfVec {

__device__ inline _halfVec& operator*=(const _halfVec<width>& other) {
if constexpr (width % 2 == 0) {
// Hint to compiler to use packed operations
#pragma unroll
for (int i = 0; i < width; i += 2) {
__half2 z = __half2{data[i], data[i+1]};
z *= __half2{other.data[i], other.data[i+1]};
Expand All @@ -91,6 +95,7 @@ struct _halfVec {

__device__ inline _halfVec& operator*=(const float scale) {
if constexpr (width % 2 == 0) {
// Hint to compiler to use packed operations
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 zf = __half22float2(__half2{data[i], data[i+1]});
Expand All @@ -109,6 +114,7 @@ struct _halfVec {
__device__ inline float sum_squares() const {
float result = 0.0f;
if constexpr (width % 2 == 0) {
// Hint to compiler to use packed operations
#pragma unroll
for (int i = 0; i < width; i += 2) {
float2 z = __half22float2(__half2{data[i], data[i+1]});
Expand Down

0 comments on commit a7164ca

Please sign in to comment.