[Model] factoring out MambaMixer out of Jamba #8993

Merged: 5 commits, Nov 4, 2024
217 changes: 217 additions & 0 deletions vllm/model_executor/layers/mamba/mamba_mixer.py
@@ -0,0 +1,217 @@
import torch
from torch import nn
from torch.nn.parameter import Parameter

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.ops.causal_conv1d import (
causal_conv1d_fn, causal_conv1d_update)
from vllm.model_executor.layers.mamba.ops.mamba_ssm import (
selective_scan_fn, selective_state_update)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
@CustomOp.register("mamba_mixer")
class MambaMixer(CustomOp):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""

def __init__(self,
hidden_size: int,
ssm_state_size: int,
conv_kernel_size: int,
intermediate_size: int,
time_step_rank: int,
use_conv_bias: bool,
use_bias: bool,
use_rms_norm: bool,
rms_norm_eps: float = 1e-5,
activation="silu"):
super().__init__()
self.time_step_rank = time_step_rank
self.ssm_state_size = ssm_state_size
self.use_rms_norm = use_rms_norm
self.activation = activation

self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size,
output_size=intermediate_size,
bias=use_conv_bias,
)
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear`, and `set_weight_attrs`
        # doesn't allow overriding it.
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

self.in_proj = MergedColumnParallelLinear(hidden_size,
[intermediate_size] * 2,
bias=use_bias)
        # selective projection used to make dt, B and C input-dependent
self.x_proj = RowParallelLinear(
intermediate_size,
time_step_rank + ssm_state_size * 2,
bias=False,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(time_step_rank,
intermediate_size,
bias=True,
skip_bias_add=True)

def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
dim=0)[tp_rank])

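        # Note: Mamba/Jamba checkpoints store this weight as `A_log`;
        # A_weight_loader converts it to A = -exp(A_log) in fp32, keeping A
        # negative so the discretized state decays rather than blows up.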
def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
weight_loader(param, -torch.exp(loaded_weight.float()))

tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(
torch.empty(
intermediate_size // tp_size,
ssm_state_size,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(intermediate_size // tp_size))

set_weight_attrs(self.D, {"weight_loader": weight_loader})
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

self.out_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=use_bias,
input_is_parallel=True,
)

self.dt_layernorm = RMSNorm(time_step_rank,
eps=rms_norm_eps) if use_rms_norm else None

self.b_layernorm = RMSNorm(ssm_state_size,
eps=rms_norm_eps) if use_rms_norm else None

self.c_layernorm = RMSNorm(ssm_state_size,
eps=rms_norm_eps) if use_rms_norm else None

def forward_native(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
conv_state: torch.Tensor, ssm_state: torch.Tensor):
pass

def forward_cuda(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
mamba_cache_params: MambaCacheParams):

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))

if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=mamba_cache_params.conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
cache_indices=mamba_cache_params.state_indices_tensor,
query_start_loc=attn_metadata.query_start_loc)
else:
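            # Decode path: a single new token per sequence, so update the
            # cached conv state in place instead of re-running the full
            # causal convolution.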
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
mamba_cache_params.conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
conv_state_indices=mamba_cache_params.state_indices_tensor)
hidden_states = hidden_states.transpose(0, 1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

time_step, B, C = torch.split(
ssm_parameters,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
if self.use_rms_norm:
assert self.dt_layernorm is not None
assert self.b_layernorm is not None
assert self.c_layernorm is not None
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())

discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)

if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
mamba_cache_params.ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
cache_indices=mamba_cache_params.state_indices_tensor,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
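            # Decode path: advance the cached SSM state by one step per
            # sequence using the fused single-token update kernel.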
scan_outputs = selective_state_update(
mamba_cache_params.ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
state_batch_indices=mamba_cache_params.state_indices_tensor)
scan_outputs = scan_outputs.transpose(0, 1)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states
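
For readers who do not want to dig into the fused kernels, the recurrence that `selective_scan_fn` and `selective_state_update` implement corresponds roughly to the naive per-channel PyTorch sketch below. This is illustrative only and is not part of the PR: `naive_selective_scan` is a hypothetical helper, the sketch is unbatched, it skips the convolution and cache handling above, and it omits the final `silu(gate)` multiplication that the fused scan applies.

import torch


def naive_selective_scan(x: torch.Tensor, delta: torch.Tensor,
                         A: torch.Tensor, B: torch.Tensor,
                         C: torch.Tensor, D: torch.Tensor) -> torch.Tensor:
    # Shapes: x, delta: (dim, seqlen); A: (dim, state); B, C: (state, seqlen);
    # D: (dim,). delta is assumed to already be positive (e.g. after softplus,
    # matching delta_softplus=True above).
    dim, seqlen = x.shape
    state_size = A.shape[1]
    h = torch.zeros(dim, state_size, dtype=torch.float32)
    ys = []
    for t in range(seqlen):
        dA = torch.exp(delta[:, t, None] * A)                    # (dim, state)
        dBx = delta[:, t, None] * B[None, :, t] * x[:, t, None]  # (dim, state)
        h = dA * h + dBx                                         # state update
        y = (h * C[None, :, t]).sum(dim=-1) + D * x[:, t]        # readout
        ys.append(y)
    return torch.stack(ys, dim=-1)                               # (dim, seqlen)

Because delta, B and C are recomputed for every token by `x_proj` and `dt_proj`, each channel's state is written and read differently per token; with fixed delta, B and C the loop reduces to the linear time-invariant recurrence of S4, which is the distinction the class docstring draws.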