Skip to content

Commit

Permalink
WIP - factor out mamba
Browse files Browse the repository at this point in the history
Signed-off-by: mzusman <[email protected]>
  • Loading branch information
mzusman committed Nov 3, 2024
1 parent 2ae25f7 commit ee13d1f
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 180 deletions.
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/mamba/ops/mamba_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,7 @@ def selective_scan_fn(
delta_softplus=False,
query_start_loc=None,
cache_indices=None,
has_initial_state=None) -> Tuple[torch.Tensor, torch.Tensor]:
has_initial_state=None) -> torch.Tensor:
"""
u: (dim, total_length) for varlen or (batch, dim, seqlen)
delta: (dim, total_length) for varlen or (batch, dim, seqlen)
Expand Down
179 changes: 0 additions & 179 deletions vllm/model_executor/models/jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,185 +51,6 @@ class MambaCacheParams:


# Adapted from transformers.models.mamba.modeling_mamba.MambaMixer
class JambaMambaMixer(nn.Module):
"""
Compute ∆, A, B, C, and D the state space parameters and compute
the `contextualized_states`. A, D are input independent
(see Mamba paper [1] Section 3.5.2 "Interpretation of A"
for why A isn't selective) ∆, B, C are input-dependent
(this is a key difference between Mamba and the linear time
invariant S4, and is why Mamba is called
**selective** state spaces)
"""

def __init__(self, config: JambaConfig, layer_idx):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.hidden_size = config.hidden_size
self.ssm_state_size = config.mamba_d_state
self.conv_kernel_size = config.mamba_d_conv
self.intermediate_size = config.mamba_expand * config.hidden_size
self.time_step_rank = config.mamba_dt_rank
self.use_conv_bias = config.mamba_conv_bias
self.use_bias = config.mamba_proj_bias
self.conv1d = ColumnParallelLinear(
input_size=self.conv_kernel_size,
output_size=self.intermediate_size,
bias=self.use_conv_bias,
)
# unsqueeze to fit conv1d weights shape into the linear weights shape.
# Can't do this in `weight_loader` since it already exists in
# `ColumnParallelLinear` and `set_weight_attrs`
# doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)

self.in_proj = MergedColumnParallelLinear(self.hidden_size,
[self.intermediate_size] * 2,
bias=self.use_bias)
# selective projection used to make dt, B and C input dependent
self.x_proj = RowParallelLinear(
self.intermediate_size,
self.time_step_rank + self.ssm_state_size * 2,
bias=False,
)
# time step projection (discretization) -
# In the forward we need to apply dt_proj without the bias,
# as the bias is added in the selective scan kernel.
self.dt_proj = ColumnParallelLinear(self.time_step_rank,
self.intermediate_size,
bias=True,
skip_bias_add=True)

def weight_loader(param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
tp_size = get_tensor_model_parallel_world_size()
param.data.copy_(
loaded_weight.data.split(loaded_weight.shape[0] // tp_size,
dim=0)[tp_rank])

def A_weight_loader(param: Parameter, loaded_weight: torch.Tensor):
weight_loader(param, -torch.exp(loaded_weight.float()))

tp_size = get_tensor_model_parallel_world_size()
self.A = nn.Parameter(
torch.empty(
self.intermediate_size // tp_size,
self.ssm_state_size,
dtype=torch.float32,
))
self.D = nn.Parameter(torch.ones(self.intermediate_size // tp_size))

set_weight_attrs(self.D, {"weight_loader": weight_loader})
set_weight_attrs(self.A, {"weight_loader": A_weight_loader})

self.out_proj = RowParallelLinear(
self.intermediate_size,
self.hidden_size,
bias=self.use_bias,
input_is_parallel=True,
)
self.activation = config.hidden_act

self.dt_layernorm = RMSNorm(self.time_step_rank,
eps=config.rms_norm_eps)
self.b_layernorm = RMSNorm(self.ssm_state_size,
eps=config.rms_norm_eps)
self.c_layernorm = RMSNorm(self.ssm_state_size,
eps=config.rms_norm_eps)

def forward(self, hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata, conv_state: torch.Tensor,
ssm_state: torch.Tensor):

# 1. Gated MLP's linear projection
projected_states = self.in_proj(hidden_states)[0].transpose(-2, -1)
hidden_states, gate = projected_states.chunk(2, dim=-2)

# 2. Convolution sequence transformation
conv_weights = self.conv1d.weight.view(self.conv1d.weight.size(0),
self.conv1d.weight.size(2))

if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
hidden_states = causal_conv1d_fn(
hidden_states,
conv_weights,
self.conv1d.bias,
activation=self.activation,
conv_states=conv_state,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
hidden_states = causal_conv1d_update(
hidden_states.transpose(0, 1),
conv_state,
conv_weights,
self.conv1d.bias,
self.activation,
)
hidden_states = hidden_states.transpose(0, 1)

# 3. State Space Model sequence transformation
# 3.a. input varying initialization of time_step, B and C
ssm_parameters = self.x_proj(hidden_states.transpose(-2, -1))[0]

time_step, B, C = torch.split(
ssm_parameters,
[self.time_step_rank, self.ssm_state_size, self.ssm_state_size],
dim=-1,
)
time_step = self.dt_layernorm(time_step.contiguous())
B = self.b_layernorm(B.contiguous())
C = self.c_layernorm(C.contiguous())

discrete_time_step = self.dt_proj(time_step)[0].transpose(-2, -1)
# 3.c perform the recurrence y ← SSM(A, B, C)(x)
time_proj_bias = (self.dt_proj.bias.float() if hasattr(
self.dt_proj, "bias") else None)

if attn_metadata.query_start_loc is not None \
and attn_metadata.context_lens_tensor is not None:
scan_outputs = selective_scan_fn(
hidden_states,
ssm_state,
discrete_time_step,
self.A,
B.transpose(-2, -1),
C.transpose(-2, -1),
self.D.float(),
gate,
time_proj_bias,
delta_softplus=True,
has_initial_state=attn_metadata.context_lens_tensor > 0,
query_start_loc=attn_metadata.query_start_loc)
else:
scan_outputs = selective_state_update(
ssm_state,
hidden_states.transpose(0, 1),
discrete_time_step.transpose(0, 1),
self.A,
B,
C,
self.D,
gate.transpose(0, 1),
time_proj_bias,
dt_softplus=True,
)
scan_outputs = scan_outputs.transpose(0, 1)

# 4. Final linear projection
contextualized_states = self.out_proj(scan_outputs.transpose(-2,
-1))[0]
return contextualized_states


class JambaMoE(nn.Module):

def __init__(self,
Expand Down

0 comments on commit ee13d1f

Please sign in to comment.