fused residual add layernorm
root committed Nov 16, 2023
1 parent 767c8e4 commit 7362eca
Showing 5 changed files with 185 additions and 13 deletions.
12 changes: 12 additions & 0 deletions csrc/layernorm.cpp
@@ -6,9 +6,21 @@ void rms_norm(
torch::Tensor& weight,
float epsilon);

void res_add_rms_norm(
torch::Tensor& out,
torch::Tensor& out_hidden,
torch::Tensor& residual,
torch::Tensor& hidden,
torch::Tensor& weight,
float epsilon);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"rms_norm",
&rms_norm,
"Apply Root Mean Square (RMS) Normalization to the input tensor.");
m.def(
"res_add_rms_norm",
&res_add_rms_norm,
"Res Add Apply Root Mean Square (RMS) Normalization to the input tensor.");
}
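
For orientation, a minimal sketch of how this binding is called from Python, mirroring the test added later in this commit; the shapes and dtypes are illustrative, and out / out_hidden are pre-allocated and written in place:

import torch
from vllm import layernorm_ops

# Illustrative shapes: 4 tokens, hidden size 1024; all tensors on the GPU.
hidden = torch.randn(4, 1024, dtype=torch.half, device='cuda')
residual = torch.randn_like(hidden)
weight = torch.ones(1024, dtype=torch.half, device='cuda')
out = torch.empty_like(hidden)         # receives the normalized output
out_hidden = torch.empty_like(hidden)  # receives hidden + residual
layernorm_ops.res_add_rms_norm(out, out_hidden, residual, hidden, weight, 1e-6)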
60 changes: 60 additions & 0 deletions csrc/layernorm_kernels.cu
@@ -33,6 +33,35 @@ __global__ void rms_norm_kernel(
}
}

template<typename scalar_t>
__global__ void res_add_rms_norm_kernel(
scalar_t* __restrict__ out, // [num_tokens, hidden_size]
scalar_t* __restrict__ out_hidden, // [num_tokens, hidden_size]
const scalar_t* __restrict__ residual, // [num_tokens, hidden_size]
const scalar_t* __restrict__ hidden, // [num_tokens, hidden_size]
const scalar_t* __restrict__ weight, // [hidden_size]
const float epsilon,
const int num_tokens,
const int hidden_size) {
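// One thread block handles one token row. Pass 1 adds the residual to the
// hidden state, stores the raw sum in out_hidden, and accumulates its squared
// sum; pass 2 rescales the stored sum by the reciprocal RMS and the weight.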
__shared__ float s_variance;
float variance = 0.0f;

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
const float x = (float) (hidden[blockIdx.x * hidden_size + idx] + residual[blockIdx.x * hidden_size + idx]);
variance += x * x;
out_hidden[blockIdx.x * hidden_size + idx] = (scalar_t) x;
}
variance = blockReduceSum<float>(variance);
if (threadIdx.x == 0) {
s_variance = rsqrtf(variance / hidden_size + epsilon);
}
__syncthreads();

for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
float x = (float) out_hidden[blockIdx.x * hidden_size + idx];
out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
}
}
} // namespace vllm

void rms_norm(
@@ -61,3 +90,34 @@ void rms_norm(
hidden_size);
});
}

void res_add_rms_norm(
torch::Tensor& out, // [num_tokens, hidden_size]
torch::Tensor& out_hidden, // [num_tokens, hidden_size]
torch::Tensor& residual, // [num_tokens, hidden_size]
torch::Tensor& hidden, // [num_tokens, hidden_size]
torch::Tensor& weight, // [hidden_size]
float epsilon) {
int num_tokens = hidden.size(0);
int hidden_size = hidden.size(1);

dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
AT_DISPATCH_FLOATING_TYPES_AND2(
at::ScalarType::Half,
at::ScalarType::BFloat16,
hidden.scalar_type(),
"rms_norm_kernel",
[&] {
vllm::res_add_rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
out.data_ptr<scalar_t>(),
out_hidden.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
hidden.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
epsilon,
num_tokens,
hidden_size);
});
}
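
As a sanity check on the math, here is a plain PyTorch sketch of what the kernel computes per token row, assuming 2-D [num_tokens, hidden_size] inputs; the function name is illustrative and not part of the commit:

import torch

def res_add_rms_norm_ref(residual, hidden, weight, epsilon):
    # Raw residual sum, kept in the input dtype (this is what out_hidden receives).
    out_hidden = hidden + residual
    # The kernel accumulates the variance in float32, so mirror that here.
    x = out_hidden.to(torch.float32)
    inv_rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + epsilon)
    # Normalize, cast back to the input dtype, then scale by the learned weight.
    out = (x * inv_rms).to(hidden.dtype) * weight
    return out, out_hidden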
63 changes: 63 additions & 0 deletions tests/kernels/test_res_add_layernorm.py
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn

from vllm import layernorm_ops


class RefRMSNorm(nn.Module):

def __init__(self, hidden_size, eps=1e-6):
super().__init__()
weight = torch.empty(hidden_size)
weight.uniform_(-1e-3, 1e-3)
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps

def forward(self, residual, hidden_states):
hidden_states = residual + hidden_states
saved = hidden_states
variance = hidden_states.to(torch.float32).pow(2).mean(-1,
keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance +
self.variance_epsilon)
if self.weight.dtype in [torch.half, torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states, saved


@torch.inference_mode()
def run_rms_norm(
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
) -> None:
residual = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
hidden = torch.randn(num_tokens, hidden_size, dtype=dtype, device='cuda')
ref = RefRMSNorm(hidden_size).to(dtype).cuda()

out = torch.empty_like(hidden)
out_res = torch.empty_like(hidden)
layernorm_ops.res_add_rms_norm(
out,
out_res,
residual,
hidden,
ref.weight.data,
ref.variance_epsilon,
)
ref_out, ref_res = ref(residual, hidden)
assert torch.allclose(out, ref_out, atol=1e-3, rtol=1e-5)
assert torch.allclose(out_res, ref_res, atol=1e-3, rtol=1e-5)


def test_rms_norm() -> None:
for dtype in [torch.half, torch.bfloat16, torch.float]:
for num_tokens in [1, 7, 128, 2048]:
for hidden_size in [13, 64, 1024, 5120, 8192]:
print(f'Testing RMS kernel with dtype={dtype}, num_tokens='
f'{num_tokens}, hidden_size={hidden_size}')
run_rms_norm(
num_tokens=num_tokens,
hidden_size=hidden_size,
dtype=dtype,
)
29 changes: 29 additions & 0 deletions vllm/model_executor/layers/layernorm.py
@@ -30,3 +30,32 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
self.variance_epsilon,
)
return out

class ResAddRMSNorm(nn.Module):
"""Residual addition fused with root mean square normalization.
Computes x = residual + hidden, then x -> w * x / sqrt(E[x^2] + eps)
where w is the learned weight, and also returns the raw sum x so it can
serve as the residual input to the next layer.
Refer to https://arxiv.org/abs/1910.07467
"""

def __init__(
self,
hidden_size: int,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps

def forward(self, residual: torch.Tensor, hidden: torch.Tensor):
out = torch.empty_like(hidden)
out_res = torch.empty_like(hidden)
layernorm_ops.res_add_rms_norm(
out,
out_res,
residual,
hidden,
self.weight.data,
self.variance_epsilon,
)
return out, out_res
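
A hedged usage sketch of how this module is meant to replace the separate add-then-normalize pair; variable names are illustrative, and the real call sites are in the llama.py changes below:

import torch
from vllm.model_executor.layers.layernorm import ResAddRMSNorm

hidden_size = 4096
norm = ResAddRMSNorm(hidden_size).cuda()
hidden = torch.randn(2, hidden_size, device='cuda')
residual = torch.zeros_like(hidden)
# One fused launch replaces: hidden = residual + hidden; residual = hidden;
# hidden = rms_norm(hidden) — and hands back the running residual as well.
hidden, residual = norm(residual, hidden)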
34 changes: 21 additions & 13 deletions vllm/model_executor/models/llama.py
@@ -33,7 +33,7 @@

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.layernorm import RMSNorm, ResAddRMSNorm
from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
@@ -187,37 +187,41 @@ def __init__(self, config: LlamaConfig):
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act
)
self.input_layernorm = RMSNorm(config.hidden_size,
self.input_layernorm = ResAddRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
self.post_attention_layernorm = ResAddRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
prev_residual: torch.Tensor,
kv_cache: KVCache,
input_metadata: InputMetadata,
cache_event: Optional[torch.cuda.Event],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
#hidden_states = prev_residual + hidden_states
#residual = hidden_states
#hidden_states = self.input_layernorm(hidden_states)
hidden_states, residual = self.input_layernorm(prev_residual, hidden_states)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
kv_cache=kv_cache,
input_metadata=input_metadata,
cache_event=cache_event,
)
hidden_states = residual + hidden_states
#hidden_states = residual + hidden_states

# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
#residual = hidden_states
#hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.post_attention_layernorm(residual, hidden_states)
hidden_states = self.mlp(hidden_states, input_metadata.num_generation_tokens)
hidden_states = residual + hidden_states
return hidden_states
#hidden_states = residual + hidden_states
return hidden_states, residual


class LlamaModel(nn.Module):
@@ -234,7 +238,7 @@ def __init__(self, config: LlamaConfig):
self.layers = nn.ModuleList([
LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.norm = ResAddRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
@@ -245,20 +249,24 @@ def forward(
cache_events: Optional[List[torch.cuda.Event]],
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = torch.zeros_like(hidden_states)
for i in range(len(self.layers)):
if cache_events is None:
cache_event = None
else:
cache_event = cache_events[i]
layer = self.layers[i]
hidden_states = layer(
hidden_states, residual = layer(
positions,
hidden_states,
residual,
kv_caches[i],
input_metadata,
cache_event,
)
hidden_states = self.norm(hidden_states)
#hidden_states = residual + hidden_states
#hidden_states = self.norm(hidden_states)
hidden_states, residual = self.norm(residual, hidden_states)
return hidden_states

