diff --git a/src/grpc/infer_handler.cc b/src/grpc/infer_handler.cc index d4d8092142..bfd3c9b2e2 100644 --- a/src/grpc/infer_handler.cc +++ b/src/grpc/infer_handler.cc @@ -973,6 +973,8 @@ ModelInferHandler::InferResponseComplete( return; } + state->context_->EraseInflightState(state); + #ifdef TRITON_ENABLE_TRACING state->trace_timestamps_.emplace_back(std::make_pair( "INFER_RESPONSE_COMPLETE", TraceManager::CaptureTimestamp())); @@ -987,7 +989,6 @@ ModelInferHandler::InferResponseComplete( "deleting GRPC inference response"); state->step_ = Steps::CANCELLED; - state->context_->EraseInflightState(state); LOG_VERBOSE(1) << "ModelInferHandler::InferResponseComplete, " << state->unique_id_ diff --git a/src/grpc/infer_handler.h b/src/grpc/infer_handler.h index 0ba9e6235f..d9bf17a068 100644 --- a/src/grpc/infer_handler.h +++ b/src/grpc/infer_handler.h @@ -640,9 +640,9 @@ class InferHandlerState { void GrpcContextAsyncNotifyWhenDone(InferHandlerStateType* state) { - InferHandlerStateType* wrapped_state = - new InferHandlerStateType(Steps::WAITING_NOTIFICATION, state); - ctx_->AsyncNotifyWhenDone(wrapped_state); + notify_state_ = std::unique_ptr( + new InferHandlerStateType(Steps::WAITING_NOTIFICATION, state)); + ctx_->AsyncNotifyWhenDone(notify_state_.get()); } void SetReceivedNotification(bool value) { received_notification_ = true; } @@ -666,8 +666,12 @@ class InferHandlerState { all_states_.insert(state); } - // Adds the state object created on this context - void EraseState(InferHandlerStateType* state) { all_states_.erase(state); } + // Erases the state object created on this context + void EraseState(InferHandlerStateType* state) + { + EraseInflightState(state); + all_states_.erase(state); + } bool HandleCompletion() { @@ -975,6 +979,10 @@ class InferHandlerState { // True if there is an ongoing write to the grpc stream std::atomic ongoing_write_; + // The state object that is sent to grpc async notification + // for tracking the gRPC stream. + std::unique_ptr notify_state_; + // Tracks whether the async notification has been delivered by // completion queue. bool received_notification_; @@ -1274,7 +1282,6 @@ InferHandler< state->context_->SetReceivedNotification(true); LOG_VERBOSE(1) << "Received notification for " << Name() << ", " << state->unique_id_; - delete state_wrapper; } LOG_VERBOSE(2) << "Grpc::CQ::Next() " << state->context_->DebugString(state); diff --git a/src/grpc/stream_infer_handler.cc b/src/grpc/stream_infer_handler.cc index 1934bf4c98..9e564d8322 100644 --- a/src/grpc/stream_infer_handler.cc +++ b/src/grpc/stream_infer_handler.cc @@ -576,6 +576,15 @@ ModelStreamInferHandler::StreamInferResponseComplete( } } + // If receiving the final callback then erase the state from the inflight + // state data structure to prevent cancellation being called on the request. + // Also make sure that if this state was sent to gRPC async notification + // mechanism then the state is not removed as it would be needed for handling + // the cancellation if detected. + if (state->complete_ && (!state->IsAsyncNotifyState())) { + state->context_->EraseInflightState(state); + } + if (state->IsGrpcContextCancelled()) { std::lock_guard lock(state->step_mtx_); // Clean-up the received response object. @@ -593,7 +602,6 @@ ModelStreamInferHandler::StreamInferResponseComplete( // that state object can be released. if (state->complete_) { state->step_ = Steps::CANCELLED; - state->context_->EraseInflightState(state); state->context_->PutTaskBackToQueue(state); } @@ -692,7 +700,6 @@ ModelStreamInferHandler::StreamInferResponseComplete( // that state object can be released. if (state->complete_) { state->step_ = Steps::CANCELLED; - state->context_->EraseInflightState(state); state->context_->PutTaskBackToQueue(state); }