Skip to content

Commit

Permalink
[Bugfix] Multi-sequence broken (#11898)
Browse files Browse the repository at this point in the history
Signed-off-by: Andy Lo <[email protected]>
  • Loading branch information
andylolu2 authored Jan 21, 2025
1 parent 132a132 commit 18fd4a8
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 39 deletions.
7 changes: 6 additions & 1 deletion tests/samplers/test_seeded_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def test_random_sample_with_seed(

sampling_params = SamplingParams(
# Parameters to ensure sufficient randomness
temperature=2.0,
temperature=3.0,
top_p=min(random.random() + 0.3, 1),
top_k=random.randint(5, 20),
n=random.randint(1, 10),
Expand Down Expand Up @@ -75,3 +75,8 @@ def test_random_sample_with_seed(
# verify requests with the same seed match
assert outputs[1] == outputs[4]
assert outputs[2] == outputs[5]

# verify generations within the same parallel sampling group differ
for output in outputs:
for sub_output_a, sub_output_b in combinations(output, 2):
assert sub_output_a != sub_output_b
2 changes: 1 addition & 1 deletion vllm/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,9 @@ def from_seq_group(
if seq_group.request_id in seq_id_to_seq_group:
group: SequenceGroupBase = seq_id_to_seq_group[
seq_group.request_id]
assembled_seq_group = group.maybe_assemble_group(seq_group)
if finished:
group.finish_seq(seq_group)
assembled_seq_group = group.maybe_assemble_group(seq_group)
if assembled_seq_group is None:
return None
return cls.from_seq_group(assembled_seq_group, use_cache,
Expand Down
89 changes: 52 additions & 37 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -815,7 +815,9 @@ def set_finished_time(self, time: Optional[float]) -> None:
def get_max_num_running_seqs(self) -> int:
"""The maximum number of sequences running in parallel in the remaining
lifetime of the request."""
return 0 if self.first_seq.is_finished() else 1
if self.is_single_seq:
return 0 if self.first_seq.is_finished() else 1
return self.num_seqs() - self.num_finished_seqs()

def get_seqs(
self,
Expand All @@ -824,7 +826,10 @@ def get_seqs(
if status is None:
return self.seqs

return self.seqs if self.first_seq.status == status else []
if self.is_single_seq:
return self.seqs if self.first_seq.status == status else []

return [seq for seq in self.seqs if seq.status == status]

def is_encoder_decoder(self) -> bool:
return self.encoder_seq is not None
Expand All @@ -833,19 +838,22 @@ def get_encoder_seq(self) -> Optional[Sequence]:
return self.encoder_seq

def get_finished_seqs(self) -> List[Sequence]:
return self.seqs if self.first_seq.is_finished() else []
if self.is_single_seq:
return self.seqs if self.first_seq.is_finished() else []

return [seq for seq in self.seqs if seq.is_finished()]

def update_num_computed_tokens(self, num_new_computed_tokens: int):
"""Update number of tokens computed so far."""
seq = self.first_seq
if not seq.is_finished():
seq.data.update_num_computed_tokens(num_new_computed_tokens)
for seq in self.seqs:
if not seq.is_finished():
seq.data.update_num_computed_tokens(num_new_computed_tokens)

def get_num_uncomputed_tokens(self) -> int:
num_uncomputed_tokens = 0
seq = self.first_seq
if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
for seq in self.seqs:
if not seq.is_finished():
num_uncomputed_tokens += seq.data.get_num_uncomputed_tokens()
return num_uncomputed_tokens

def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
Expand All @@ -860,10 +868,14 @@ def num_seqs(self, status: Optional[SequenceStatus] = None) -> int:
return len(self.get_seqs(status))

def num_finished_seqs(self) -> int:
return 1 if self.first_seq.is_finished() else 0
if self.is_single_seq:
return 1 if self.seqs[0].is_finished() else 0
return len(self.get_finished_seqs())

def is_finished(self) -> bool:
return self.first_seq.is_finished()
if self.is_single_seq:
return self.first_seq.is_finished()
return all(seq.is_finished() for seq in self.seqs)

def is_prefill(self) -> bool:
return self.first_seq.is_prefill()
Expand Down Expand Up @@ -1391,13 +1403,15 @@ class ParallelSampleSequenceGroup(SequenceGroupBase):
@staticmethod
def add_request(request_id: str, engine, params, **kwargs):
original_params = params
params = original_params.clone()
params.n = 1
group = ParallelSampleSequenceGroup(request_id)
seqs = []
for i in range(original_params.n):
request_id_i = f"{request_id}_parallel_sample_{i}"
group.seq_id_to_index[request_id_i] = i
params = copy.deepcopy(original_params)
params.n = 1
if params.seed is not None:
params.seed += i
seq_group = engine._add_processed_request(
request_id_i,
params=params,
Expand Down Expand Up @@ -1432,33 +1446,34 @@ def maybe_assemble_group(
self, seq_group: SequenceGroup) -> Optional[SequenceGroup]:

# in the streaming mode, we will return the assembled sequence
# for the first sequence, and then return None for the rest of
# sequences
# for the first remaining sequence, and then return None for the
# rest of sequences
if self.streaming:
if self.seq_id_to_index[seq_group.request_id] == 0:
first_remaining_id = next(iter(self.to_be_finished))
if seq_group.request_id == first_remaining_id:
return self.assembled_seq_group
return None

# in the non-streaming mode, we will return the assembled sequence
# once after all sequences finish, and then return None for the
# when the last sequences finishes, and then return None for the
# rest of the time

if len(self.to_be_finished) > 0:
return None

assert self.assembled_seq_group is not None
params = self.assembled_seq_group.sampling_params
assert isinstance(params, SamplingParams)
if not self.output_produced:
self.output_produced = True
if params._real_n is not None:
# Get the top-n sequences.
n = params._real_n or params.n
seqs = self.assembled_seq_group.seqs
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
self.assembled_seq_group.seqs = top_n_seqs
return self.assembled_seq_group
if self.output_produced:
return None
if (len(self.to_be_finished) == 1
and seq_group.request_id in self.to_be_finished
and seq_group.is_finished()):
assert self.assembled_seq_group is not None
params = self.assembled_seq_group.sampling_params
assert isinstance(params, SamplingParams)
if not self.output_produced:
self.output_produced = True
if params._real_n is not None:
# Get the top-n sequences.
n = params._real_n or params.n
seqs = self.assembled_seq_group.seqs
sorting_key = lambda seq: seq.get_cumulative_logprob()
sorted_seqs = sorted(seqs, key=sorting_key, reverse=True)
top_n_seqs = sorted_seqs[:n]
self.assembled_seq_group.seqs = top_n_seqs
return self.assembled_seq_group
if self.output_produced:
return None
return None

0 comments on commit 18fd4a8

Please sign in to comment.