Skip to content

Commit

Permalink
[Bugfix][SpecDecode] Adjust Eagle model architecture to align with intended design (#11672)
Browse files Browse the repository at this point in the history

Signed-off-by: Sungjae Lee <[email protected]>
  • Loading branch information
llsj14 authored Jan 11, 2025
1 parent 899136b commit 2118d05
Showing 1 changed file with 24 additions and 2 deletions.
26 changes: 24 additions & 2 deletions vllm/model_executor/models/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,30 @@
from .utils import maybe_prefix


class DummyInputLayerNorm(nn.Module):
    """Identity stand-in for a decoder layer's ``input_layernorm``.

    EAGLE installs this module on the first decoder layer so that the
    layer's input passes through without any normalization. The module
    holds no parameters or buffers.
    """

    def forward(self, x):
        """Return ``x`` unchanged (the same object, no copy is made)."""
        return x


class DummyOutputNorm(nn.Module):
    """Normalization-free stand-in for the model's final ``norm`` module.

    EAGLE installs this module as the wrapped model's output norm so that
    no normalization is applied to the final hidden states. It keeps the
    two calling conventions of the module it replaces:

    * ``norm(x)`` / ``norm(x, None)`` returns a single tensor.
    * ``norm(x, residual)`` returns a ``(hidden_states, residual)`` pair.
    """

    def forward(self, x, residual=None):
        """Skip normalization while still applying the pending residual.

        Args:
            x: Output of the last decoder layer.
            residual: Residual stream not yet added to ``x``, or ``None``
                when there is nothing left to add.

        Returns:
            ``x`` itself when ``residual`` is ``None``; otherwise the pair
            ``(x + residual, x + residual)``.
        """
        if residual is None:
            return x
        # A fused add+norm layer folds the residual into the hidden states
        # before normalizing. Dropping only the normalization must keep that
        # addition, otherwise the last layer's residual contribution is lost.
        # The inputs are left unmodified (out-of-place add).
        merged = x + residual
        # Mirror the (output, updated residual) shape of the replaced norm.
        return merged, merged


class EAGLE(nn.Module):
"""This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
Reference implementation: https://github.com/SafeAILab/EAGLE
Differences from reference implementation:
1. In reference, LlamaDecoderLayer implementation doesn't have
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427)
but we do as HF implementation also does.
input_layernorm for 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427).
Following this approach, our implementation also disables
the input_layernorm for the first decoder layer.
2. We allow any decoder layer to be used in EAGLE whereas in reference
decoder layer is fixed to be LlamaDecoderLayer.
3. We have an optional token_map which reduces draft vocab to most
Expand All @@ -46,10 +62,16 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

self.model = model_cls(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))

self.fc = nn.Linear(config.model.hidden_size * 2,
config.model.hidden_size,
bias=getattr(self.config, "eagle_fc_bias", False))

# Modify layer normalization and residual connections as suggested
# in the EAGLE framework: https://github.com/SafeAILab/EAGLE
self.model.model.layers[0].input_layernorm = DummyInputLayerNorm()
self.model.model.norm = DummyOutputNorm()

self.orig_vocab_size = config.vocab_size
self.truncated_vocab_size = config.truncated_vocab_size
self.unpadded_vocab_size = self.truncated_vocab_size
Expand Down

0 comments on commit 2118d05

Please sign in to comment.