Skip to content

Commit

Permalink
[Bugfix][Mamba] Fix Multistep on Mamba-like models (#10705)
Browse files Browse the repository at this point in the history
Signed-off-by: mzusman <[email protected]>
  • Loading branch information
mzusman authored Nov 27, 2024
1 parent b98c62b commit 197b448
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 4 deletions.
38 changes: 38 additions & 0 deletions tests/models/decoder_only/language/test_jamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,44 @@ def test_state_cleanup(
"could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
# This test is verifying that multistep works correctly
#on mamba-like models
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 10, 1)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
max_tokens: int, example_prompts) -> None:
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
vllm_outputs_multistep = vllm_model.generate_greedy(
example_prompts, max_tokens)

with vllm_runner(model, num_scheduler_steps=1,
max_num_seqs=2) as vllm_model:
vllm_outputs_single_step = vllm_model.generate_greedy(
example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=vllm_outputs_multistep,
outputs_1_lst=vllm_outputs_single_step,
name_0="vllm_outputs_multistep",
name_1="vllm_outputs_single_step",
)


@multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
Expand Down
36 changes: 36 additions & 0 deletions tests/models/decoder_only/language/test_mamba.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,3 +283,39 @@ def test_state_cleanup(
except ValueError:
pytest.fail("Mamba inner state wasn't cleaned up between states, "
"could be related to finished_requests_ids")


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
def test_multistep(
vllm_runner,
model: str,
dtype: str,
example_prompts,
) -> None:
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
vllm_model.generate_greedy([example_prompts[0]] * 10, 1)


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [64])
def test_multistep_correctness(vllm_runner, model: str, dtype: str,
max_tokens: int, example_prompts) -> None:
with vllm_runner(model, num_scheduler_steps=8,
max_num_seqs=2) as vllm_model:
vllm_outputs_multistep = vllm_model.generate_greedy(
example_prompts, max_tokens)

with vllm_runner(model, num_scheduler_steps=1,
max_num_seqs=2) as vllm_model:
vllm_outputs_single_step = vllm_model.generate_greedy(
example_prompts, max_tokens)

check_outputs_equal(
outputs_0_lst=vllm_outputs_multistep,
outputs_1_lst=vllm_outputs_single_step,
name_0="vllm_outputs_multistep",
name_1="vllm_outputs_single_step",
)
7 changes: 5 additions & 2 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,6 +300,9 @@ async def step_async(
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs

finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
Expand All @@ -311,13 +314,13 @@ async def step_async(
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
else:
finished_requests_ids = list()

assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()

# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
Expand Down
7 changes: 5 additions & 2 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,6 +1398,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs

finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()

# Maybe switch from async mode to sync mode
if not allow_async_output_proc and len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
Expand All @@ -1409,13 +1412,13 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
self._cache_scheduler_outputs_for_multi_step(
virtual_engine, seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc)
else:
finished_requests_ids = list()

assert seq_group_metadata_list is not None
assert scheduler_outputs is not None

if not scheduler_outputs.is_empty():
finished_requests_ids = self.scheduler[
virtual_engine].get_and_reset_finished_requests_ids()

# Check if we have a cached last_output from the previous iteration.
# For supporting PP this is probably the best way to pass the
Expand Down

0 comments on commit 197b448

Please sign in to comment.