[bugfix] fix weak ref in piecewise cudagraph and tractable test (vllm-project#10048)

Signed-off-by: youkaichao <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
youkaichao authored and tlrmchlsmth committed Nov 23, 2024
1 parent bf51158 commit 3060e3b
Showing 2 changed files with 168 additions and 25 deletions.
111 changes: 101 additions & 10 deletions tests/compile/piecewise/test_toy_llama.py
@@ -1,6 +1,10 @@
"""
Test the piecewise compilation with a simple model, comparing the output
with and without the piecewise compilation.
+
+This is a tractable model: the weights and computation are specially designed
+if the config `tractable_init` is set to True. Otherwise, the weights are
+initialized randomly with a fixed seed.
"""
import os
from dataclasses import dataclass
@@ -49,6 +53,12 @@ class LlamaConfig:
    mlp_size: int = 256
    vocab_size: int = 128
    num_layers: int = 2
+    init_value: float = 1.0
+    tractable_init: bool = False
+    random_seed: int = 0
+
+    def __post_init__(self):
+        assert self.mlp_size >= self.hidden_size
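The new assertion guards the tractable path below: nn.init.eye_ fills a rectangular 2-D tensor with a zero-padded identity, so an identity block only passes the full hidden dimension through when mlp_size >= hidden_size. A minimal standalone sketch of that behavior (illustration only, not part of this commit):

import torch
import torch.nn as nn

w = torch.empty(4, 2)  # rows >= cols, like mlp_size >= hidden_size
nn.init.eye_(w)        # top-left identity block, zeros elsewhere
x = torch.tensor([1.0, 2.0])
assert torch.equal(w @ x, torch.tensor([1.0, 2.0, 0.0, 0.0]))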


class LlamaMLP(nn.Module):
@@ -66,10 +76,23 @@ def __init__(self, config: LlamaConfig) -> None:
            bias=False,
        )

-        self.gate_up_projection.weight.data.fill_(0.0)
-        self.down_projection.weight.data.fill_(0.0)
+        if config.tractable_init:
+            nn.init.eye_(self.gate_up_projection.weight.data[:config.mlp_size])
+            nn.init.eye_(self.gate_up_projection.weight.data[config.mlp_size:])
+            nn.init.eye_(self.down_projection.weight.data)
+        else:
+            nn.init.xavier_normal_(self.gate_up_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
+            nn.init.xavier_normal_(self.down_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)

    def forward(self, x):
+        # for tractable_init and positive input, this is
+        # essentially an elementwise-square
        x = self.gate_up_projection(x)
        x = x[:, :x.size(1) // 2] * torch.nn.functional.relu(
            x[:, x.size(1) // 2:])
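With the identity initialization above and positive activations, the gate half times the relu of the up half reduces to an elementwise square. A quick standalone check of that algebra, assuming equal hidden and mlp sizes for brevity (not part of this commit):

import torch

x = torch.rand(3, 8) + 0.1          # strictly positive input
gate_up = torch.cat([x, x], dim=1)  # identity gate and up projections
out = gate_up[:, :8] * torch.relu(gate_up[:, 8:])
assert torch.allclose(out, x ** 2)  # elementwise square, as the comment says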
@@ -84,21 +107,39 @@ def __init__(self, config: LlamaConfig) -> None:
        self.qkv_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size * 3,
            bias=False,
        )

        self.output_projection = nn.Linear(
            in_features=config.hidden_size,
            out_features=config.hidden_size,
            bias=False,
        )

-        self.qkv_projection.weight.data.fill_(0.0)
-        self.output_projection.weight.data.fill_(0.0)
+        if config.tractable_init:
+            nn.init.eye_(self.qkv_projection.weight.data[:config.hidden_size])
+            nn.init.eye_(self.qkv_projection.weight.data[config.hidden_size:2 *
+                                                         config.hidden_size])
+            nn.init.eye_(self.qkv_projection.weight.data[2 *
+                                                         config.hidden_size:])
+            nn.init.eye_(self.output_projection.weight.data)
+        else:
+            nn.init.xavier_normal_(self.qkv_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)
+            nn.init.xavier_normal_(self.output_projection.weight.data,
+                                   generator=torch.Generator().manual_seed(
+                                       config.random_seed),
+                                   gain=0.001)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
+        # for tractable_init, this is:
+        # output = (hidden_states * 3 + positions * 2)
        qkv = self.qkv_projection(hidden_states)
        hidden_size = qkv.size(-1) // 3
        q, k, v = qkv.split([hidden_size, hidden_size, hidden_size], dim=-1)
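The comment above says the whole attention block collapses, under tractable_init, to hidden_states * 3 + positions * 2. One way that identity falls out, assuming (this part of the hunk is folded) that q and k are shifted by the position index and the three streams are summed before the identity output projection; a standalone sketch, not the committed code:

import torch

h = torch.rand(4, 8)                      # hidden_states, one row per token
p = torch.arange(4).unsqueeze(1).float()  # positions
q = k = v = h                             # identity qkv projection: three copies
q = q + p                                 # assumed positional shift on q
k = k + p                                 # assumed positional shift on k
out = q + k + v                           # assumed sum; output projection is identity
assert torch.allclose(out, h * 3 + p * 2)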
@@ -126,20 +167,29 @@ def forward(
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        For tractable computation:
+        - if residual is None, the outputs are:
+            - residual = (hidden_states + 1) * 3 + positions * 2 + hidden_states = hidden_states * 4 + positions * 2 + 3
+            - hidden_states = (residual + 1) ** 2
+        - if residual is not None, the outputs are:
+            - residual = (hidden_states + residual + 1) * 3 + positions * 2 + hidden_states + residual = (hidden_states + residual) * 4 + positions * 2 + 3
+            - hidden_states = (residual + 1) ** 2
+        """  # noqa
        if residual is None:
            residual = hidden_states
-            hidden_states = hidden_states / 2
+            hidden_states = hidden_states + 1
        else:
            hidden_states = hidden_states + residual
            residual = hidden_states
-            hidden_states = hidden_states / 2
+            hidden_states = hidden_states + 1

        hidden_states = self.self_attention(positions=positions,
                                            hidden_states=hidden_states)

        hidden_states = hidden_states + residual
        residual = hidden_states
-        hidden_states = hidden_states / 2
+        hidden_states = hidden_states + 1
        hidden_states = self.mlp(hidden_states)

        return hidden_states, residual
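Chaining the tractable pieces reproduces the docstring: attention sees hidden_states + 1 and returns (hidden_states + 1) * 3 + positions * 2, adding the residual back gives hidden_states * 4 + positions * 2 + 3, and the MLP then squares residual + 1. A numeric spot check of the residual-is-None branch (standalone sketch; positive values assumed throughout):

import torch

h = torch.full((2, 4), 1.0)               # initial hidden states (init_value = 1.0)
p = torch.arange(2).unsqueeze(1).float()  # positions, broadcast over the hidden dim

attn_out = (h + 1) * 3 + p * 2            # tractable attention on the "+ 1" shift
residual = attn_out + h                   # add residual back: h * 4 + p * 2 + 3
assert torch.allclose(residual, h * 4 + p * 2 + 3)
hidden = (residual + 1) ** 2              # "+ 1" again, then the squaring MLP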
@@ -156,7 +206,8 @@ def __init__(self, config: LlamaConfig) -> None:
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config) for _ in range(config.num_layers)])

-        self.embedding_tokens.weight.data.fill_(0.0)
+        # this is the initial value of the hidden states
+        self.embedding_tokens.weight.data.fill_(config.init_value)

    def forward(
        self,
@@ -170,6 +221,28 @@ def forward(
        return hidden_states


+def tractable_computation(input_ids: torch.Tensor,
+                          positions: torch.Tensor,
+                          config: LlamaConfig,
+                          init_value: float = 1.0) -> torch.Tensor:
+    hidden_states = torch.ones(input_ids.size(0),
+                               config.hidden_size,
+                               device=input_ids.device,
+                               dtype=input_ids.dtype) * init_value
+
+    # first layer
+    residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
+    hidden_states = (residual + 1)**2
+
+    # following layers
+    for _ in range(config.num_layers - 1):
+        hidden_states = hidden_states + residual
+        residual = hidden_states * 4 + positions.unsqueeze(1) * 2 + 3
+        hidden_states = (residual + 1)**2
+
+    return hidden_states
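A minimal driver for the closed form (hypothetical usage, not in the commit: float tensors are used because the helper reuses the dtype of input_ids for the hidden states, and only the batch size and device of input_ids matter once every embedding row equals init_value):

import torch

config = LlamaConfig(hidden_size=128, mlp_size=256, vocab_size=128,
                     num_layers=2, tractable_init=True)
input_ids = torch.zeros(2, dtype=torch.float32)
positions = torch.arange(2, dtype=torch.float32)
expected = tractable_computation(input_ids, positions, config)
print(expected.shape)  # torch.Size([2, 128])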


@torch.inference_mode
def run_model(llama_config,
              use_compile: bool,
@@ -213,7 +286,15 @@ def run_model(llama_config,
    del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
    set_compilation_config(None)

-    return output.cpu()
+    output = output.cpu()
+
+    if llama_config.tractable_init:
+        expected_output = tractable_computation(input_ids[:2], positions[:2],
+                                                llama_config).cpu()
+
+        assert torch.allclose(output, expected_output)
+    else:
+        return output.cpu()


def test_toy_llama():
@@ -222,7 +303,13 @@ def test_toy_llama():
    llama_config = LlamaConfig(hidden_size=128,
                               mlp_size=256,
                               vocab_size=128,
-                               num_layers=2)
+                               num_layers=12)
+
+    tractable_config = LlamaConfig(hidden_size=128,
+                                   mlp_size=256,
+                                   vocab_size=128,
+                                   num_layers=2,
+                                   tractable_init=True)

    outputs = []
    with compilation_counter.expect(
@@ -233,6 +320,8 @@
            num_cudagraph_caputured=0,
    ):
        outputs.append(run_model(llama_config, use_compile=False))
+        run_model(tractable_config, use_compile=False)

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=1,
@@ -242,6 +331,7 @@
            2,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):
        outputs.append(run_model(llama_config, use_compile=True))
+        run_model(tractable_config, use_compile=True)

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
@@ -257,6 +347,7 @@
    ):
        outputs.append(
            run_model(llama_config, use_compile=True, split_attn=True))
+        run_model(tractable_config, use_compile=True, split_attn=True)

    for i in range(1, len(outputs)):
        assert torch.allclose(outputs[0], outputs[i])
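The hidden expectations for the split-attention run scale with the layer count. A sketch of the arithmetic they presumably encode (the split-case formulas are an assumption here; only the num_cudagraph_sizes * num_piecewise_capturable_graphs_seen comment is visible above):

num_layers = 12          # llama_config above
num_cudagraph_sizes = 2  # implied by 2 = 2 * 1 in the unsplit compiled run

# splitting at every attention op interleaves capturable pieces and attention pieces
piecewise_graphs = 2 * num_layers + 1
capturable_graphs = num_layers + 1  # assumed: attention pieces are not captured
cudagraphs_captured = num_cudagraph_sizes * capturable_graphs
print(piecewise_graphs, capturable_graphs, cudagraphs_captured)  # 25 13 26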
