[Bugfix] Add synchronize to prevent possible data race (vllm-project#6788)

Co-authored-by: Lucas Wilkinson <[email protected]>
tlrmchlsmth and LucasWilkinson authored Jul 25, 2024
1 parent 65b1f12 commit 95db75d
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions vllm/distributed/parallel_state.py
@@ -243,6 +243,13 @@ def graph_capture(
         ca_comm = self.ca_comm
         maybe_ca_context = nullcontext(
         ) if ca_comm is None else ca_comm.capture()
+
+        # ensure all initialization operations complete before attempting to
+        # capture the graph on another stream
+        curr_stream = torch.cuda.current_stream()
+        if curr_stream != stream:
+            stream.wait_stream(curr_stream)
+
         with torch.cuda.stream(stream), maybe_ca_context:
             # In graph mode, we have to be very careful about the collective
             # operations. The current status is:
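For reference, below is a minimal standalone sketch of the stream-synchronization pattern this commit introduces: making a side capture stream wait on work already queued on the current stream before CUDA graph capture begins. It assumes a CUDA-capable PyTorch build; the helper name capture_on_side_stream and its arguments are illustrative and not part of vLLM.

    # Minimal sketch of the pattern added by this commit (not vLLM code):
    # synchronize the side stream with the current stream before graph capture.
    import torch


    def capture_on_side_stream(fn, *args):
        """Illustrative helper: capture fn(*args) into a CUDA graph on a side stream."""
        stream = torch.cuda.Stream()

        # Without this wait, initialization work still pending on the current
        # stream could race with the capture that starts on `stream` below.
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.stream(stream):
            fn(*args)  # warm-up run on the side stream before capturing
            with torch.cuda.graph(graph, stream=stream):
                out = fn(*args)
        return graph, out

Once captured, graph.replay() re-runs the recorded kernels; the wait_stream call is what rules out starting capture while earlier initialization kernels are still in flight, which is the possible data race the commit title refers to.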
