Skip to content

Commit

Permalink
Fix notify state destruction and inflight states tracking (#6451) (#6457)
Browse files Browse the repository at this point in the history

* Ensure notify_state_ gets properly destructed

* Fix inflight state tracking to properly erase states

* Prevent the notify_state from being erased

* Wrap notify_state_ object within unique_ptr
  • Loading branch information
tanmayv25 authored Oct 19, 2023
1 parent 049cce5 commit f4eb6bd
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/grpc/infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,8 @@ ModelInferHandler::InferResponseComplete(
return;
}

state->context_->EraseInflightState(state);

#ifdef TRITON_ENABLE_TRACING
state->trace_timestamps_.emplace_back(std::make_pair(
"INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp()));
Expand All @@ -987,7 +989,6 @@ ModelInferHandler::InferResponseComplete(
"deleting GRPC inference response");

state->step_ = Steps::CANCELLED;
state->context_->EraseInflightState(state);

LOG_VERBOSE(1) << "ModelInferHandler::InferResponseComplete, "
<< state->unique_id_
Expand Down
19 changes: 13 additions & 6 deletions src/grpc/infer_handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,9 @@ class InferHandlerState {

void GrpcContextAsyncNotifyWhenDone(InferHandlerStateType* state)
{
InferHandlerStateType* wrapped_state =
new InferHandlerStateType(Steps::WAITING_NOTIFICATION, state);
ctx_->AsyncNotifyWhenDone(wrapped_state);
notify_state_ = std::unique_ptr<InferHandlerStateType>(
new InferHandlerStateType(Steps::WAITING_NOTIFICATION, state));
ctx_->AsyncNotifyWhenDone(notify_state_.get());
}

// Records whether the async notification for this context has been
// delivered by the completion queue. Stores the caller-supplied value
// instead of unconditionally setting the flag to true, so the argument
// is actually honored.
void SetReceivedNotification(bool value) { received_notification_ = value; }
Expand All @@ -666,8 +666,12 @@ class InferHandlerState {
all_states_.insert(state);
}

// Adds the state object created on this context
void EraseState(InferHandlerStateType* state) { all_states_.erase(state); }
// Erases the state object created on this context.
// The state is first passed to EraseInflightState() and is then removed
// from all_states_, so a single call clears it from both places.
void EraseState(InferHandlerStateType* state)
{
EraseInflightState(state);
all_states_.erase(state);
}

bool HandleCompletion()
{
Expand Down Expand Up @@ -975,6 +979,10 @@ class InferHandlerState {
// True if there is an ongoing write to the grpc stream
std::atomic<bool> ongoing_write_;

// The state object that is sent to grpc async notification
// for tracking the gRPC stream.
std::unique_ptr<InferHandlerState> notify_state_;

// Tracks whether the async notification has been delivered by
// completion queue.
bool received_notification_;
Expand Down Expand Up @@ -1274,7 +1282,6 @@ InferHandler<
state->context_->SetReceivedNotification(true);
LOG_VERBOSE(1) << "Received notification for " << Name() << ", "
<< state->unique_id_;
delete state_wrapper;
}
LOG_VERBOSE(2) << "Grpc::CQ::Next() "
<< state->context_->DebugString(state);
Expand Down
11 changes: 9 additions & 2 deletions src/grpc/stream_infer_handler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,15 @@ ModelStreamInferHandler::StreamInferResponseComplete(
}
}

// If this is the final callback, erase the state from the inflight state
// data structure to prevent cancellation from being called on the request.
// However, a state that was sent to the gRPC async notification mechanism
// must not be removed here, because it is still needed to handle the
// cancellation if one is detected.
if (state->complete_ && (!state->IsAsyncNotifyState())) {
state->context_->EraseInflightState(state);
}

if (state->IsGrpcContextCancelled()) {
std::lock_guard<std::recursive_mutex> lock(state->step_mtx_);
// Clean-up the received response object.
Expand All @@ -593,7 +602,6 @@ ModelStreamInferHandler::StreamInferResponseComplete(
// that state object can be released.
if (state->complete_) {
state->step_ = Steps::CANCELLED;
state->context_->EraseInflightState(state);
state->context_->PutTaskBackToQueue(state);
}

Expand Down Expand Up @@ -692,7 +700,6 @@ ModelStreamInferHandler::StreamInferResponseComplete(
// that state object can be released.
if (state->complete_) {
state->step_ = Steps::CANCELLED;
state->context_->EraseInflightState(state);
state->context_->PutTaskBackToQueue(state);
}

Expand Down

0 comments on commit f4eb6bd

Please sign in to comment.