diff --git a/qa/L0_grpc/test.sh b/qa/L0_grpc/test.sh index c22390a82f..73b9710a71 100755 --- a/qa/L0_grpc/test.sh +++ b/qa/L0_grpc/test.sh @@ -625,6 +625,23 @@ elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then RET=1 fi +# Unknown protocol, not allowed +SERVER_ARGS="--model-repository=${MODELDIR} \ + --grpc-restricted-protocol=model-reposit,health:k1=v1 \ + --grpc-restricted-protocol=metadata,health:k2=v2" +run_server +EXPECTED_MSG="unknown restricted protocol 'model-reposit'" +if [ "$SERVER_PID" != "0" ]; then + echo -e "\n***\n*** Expect fail to start $SERVER\n***" + kill $SERVER_PID + wait $SERVER_PID + RET=1 +elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then + echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***" + cat $SERVER_LOG + RET=1 +fi + # Test restricted protocols SERVER_ARGS="--model-repository=${MODELDIR} \ --grpc-restricted-protocol=model-repository:admin-key=admin-value \ diff --git a/qa/L0_http/http_restricted_api_test.py b/qa/L0_http/http_restricted_api_test.py index 2dcc5c6555..73b6a09645 100755 --- a/qa/L0_http/http_restricted_api_test.py +++ b/qa/L0_http/http_restricted_api_test.py @@ -48,8 +48,8 @@ def test_sanity(self): "simple", headers={"infer-key": "infer-value"} ) - # health, infer, model repository APIs are restricted. - # health and infer expects "infer-key : infer-value" header, + # metadata, infer, model repository APIs are restricted. + # metadata and infer expects "infer-key : infer-value" header, # model repository expected "admin-key : admin-value". def test_model_repository(self): with self.assertRaisesRegex(InferenceServerException, "This API is restricted"): @@ -64,6 +64,11 @@ def test_model_repository(self): self.model_name_, headers={"admin-key": "admin-value"} ) + def test_metadata(self): + with self.assertRaisesRegex(InferenceServerException, "This API is restricted"): + self.client_.get_server_metadata() + self.client_.get_server_metadata({"infer-key": "infer-value"}) + def test_infer(self): # setup inputs = [ diff --git a/qa/L0_http/test.sh b/qa/L0_http/test.sh index 04097cc5d2..c9ad809525 100755 --- a/qa/L0_http/test.sh +++ b/qa/L0_http/test.sh @@ -671,6 +671,8 @@ kill $SERVER_PID wait $SERVER_PID ### Test Restricted APIs ### +### Repeated API not allowed + MODELDIR="`pwd`/models" SERVER_ARGS="--model-repository=${MODELDIR} --http-restricted-api=model-repository,health:k1=v1 \ @@ -688,11 +690,32 @@ elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then RET=1 fi +### Test Unknown Restricted API### +### Unknown API not allowed + +MODELDIR="`pwd`/models" +SERVER_ARGS="--model-repository=${MODELDIR} + --http-restricted-api=model-reposit,health:k1=v1 \ + --http-restricted-api=metadata,health:k2=v2" +run_server +EXPECTED_MSG="unknown restricted api 'model-reposit'" +if [ "$SERVER_PID" != "0" ]; then + echo -e "\n***\n*** Expect fail to start $SERVER\n***" + kill $SERVER_PID + wait $SERVER_PID + RET=1 +elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then + echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***" + cat $SERVER_LOG + RET=1 +fi + ### Test Restricted APIs ### +### Restricted model-repository, metadata, and inference SERVER_ARGS="--model-repository=${MODELDIR} \ --http-restricted-api=model-repository:admin-key=admin-value \ - --http-restricted-api=inference:infer-key=infer-value" + --http-restricted-api=inference,metadata:infer-key=infer-value" run_server if [ "$SERVER_PID" == "0" ]; then echo -e "\n***\n*** Failed to start $SERVER\n***" diff --git a/src/command_line_parser.cc b/src/command_line_parser.cc index 5b57182952..f771fc343d 100644 --- a/src/command_line_parser.cc +++ b/src/command_line_parser.cc @@ -1919,7 +1919,7 @@ void TritonParser::ParseRestrictedFeatureOption( const std::string& arg, const std::string& option_name, const std::string& key_prefix, const std::string& feature_type, - RestrictedFeatureMap& restricted_features) + RestrictedFeatures& restricted_features) { const auto& parsed_tuple = ParseGenericConfigOption(arg, ":", "=", option_name, "config name"); @@ -1929,14 +1929,24 @@ TritonParser::ParseRestrictedFeatureOption( const auto& value = std::get<2>(parsed_tuple); for (const auto& feature : features) { - if (restricted_features.count(feature)) { + const auto& category = RestrictedFeatures::ToCategory(feature); + + if (category == RestrictedCategory::INVALID) { + std::stringstream ss; + ss << "unknown restricted " << feature_type << " '" << feature << "' " + << std::endl; + throw ParseException(ss.str()); + } + + if (restricted_features.IsRestricted(category)) { // restricted feature can only be in one group std::stringstream ss; ss << "restricted " << feature_type << " '" << feature << "' can not be specified in multiple config groups" << std::endl; throw ParseException(ss.str()); } - restricted_features[feature] = std::make_pair(key_prefix + key, value); + restricted_features.Insert( + category, std::make_pair(key_prefix + key, value)); } } diff --git a/src/command_line_parser.h b/src/command_line_parser.h index bb306f366e..bb4368fb84 100644 --- a/src/command_line_parser.h +++ b/src/command_line_parser.h @@ -35,6 +35,7 @@ #include #include +#include "restricted_features.h" #include "triton/common/logging.h" #include "triton/core/tritonserver.h" #ifdef TRITON_ENABLE_GRPC @@ -189,7 +190,7 @@ struct TritonServerParameters { std::string http_forward_header_pattern_; // The number of threads to initialize for the HTTP front-end. int http_thread_cnt_{8}; - RestrictedFeatureMap http_restricted_apis_{}; + RestrictedFeatures http_restricted_apis_{}; #endif // TRITON_ENABLE_HTTP #ifdef TRITON_ENABLE_GRPC @@ -285,7 +286,7 @@ class TritonParser { void ParseRestrictedFeatureOption( const std::string& arg, const std::string& option_name, const std::string& header_prefix, const std::string& feature_name, - RestrictedFeatureMap& restricted_features); + RestrictedFeatures& restricted_features); #ifdef TRITON_ENABLE_TRACING TRITONSERVER_InferenceTraceLevel ParseTraceLevelOption(std::string arg); InferenceTraceMode ParseTraceModeOption(std::string arg); diff --git a/src/common.h b/src/common.h index 2833727221..0704daa3a1 100644 --- a/src/common.h +++ b/src/common.h @@ -25,6 +25,7 @@ // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. #pragma once +#include #include #include #include @@ -171,8 +172,4 @@ bool Contains(const std::vector& vec, const std::string& str); /// \return The joint string. std::string Join(const std::vector& vec, const std::string& delim); -using RestrictedFeature = std::pair; - -using RestrictedFeatureMap = std::map; - }} // namespace triton::server diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index eac05f93c9..f9dbd5c016 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -253,7 +253,7 @@ class CommonHandler : public HandlerBase { inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, ::grpc::ServerCompletionQueue* cq, - const RestrictedFeatureMap& restricted_keys); + const RestrictedFeatures& restricted_keys); // Descriptive name of of the handler. const std::string& Name() const { return name_; } @@ -298,13 +298,9 @@ class CommonHandler : public HandlerBase { ::grpc::health::v1::Health::AsyncService* health_service_; ::grpc::ServerCompletionQueue* cq_; std::unique_ptr thread_; - const RestrictedFeatureMap& restricted_keys_; - static std::pair empty_restricted_key_; + const RestrictedFeatures& restricted_keys_; }; -std::pair CommonHandler::empty_restricted_key_{ - "", ""}; - CommonHandler::CommonHandler( const std::string& name, const std::shared_ptr& tritonserver, @@ -313,7 +309,7 @@ CommonHandler::CommonHandler( inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, ::grpc::ServerCompletionQueue* cq, - const RestrictedFeatureMap& restricted_keys) + const RestrictedFeatures& restricted_keys) : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), trace_manager_(trace_manager), service_(service), health_service_(health_service), cq_(cq), @@ -438,9 +434,8 @@ CommonHandler::RegisterServerLive() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("health"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ServerLiveRequest, inference::ServerLiveResponse>( @@ -475,9 +470,8 @@ CommonHandler::RegisterServerReady() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("health"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ServerReadyRequest, inference::ServerReadyResponse>( @@ -523,9 +517,8 @@ CommonHandler::RegisterHealthCheck() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("health"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); new CommonCallData< ::grpc::ServerAsyncResponseWriter< ::grpc::health::v1::HealthCheckResponse>, @@ -568,9 +561,8 @@ CommonHandler::RegisterModelReady() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("health"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ModelReadyRequest, inference::ModelReadyResponse>( @@ -647,9 +639,8 @@ CommonHandler::RegisterServerMetadata() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("metadata"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ServerMetadataRequest, inference::ServerMetadataResponse>( @@ -816,9 +807,8 @@ CommonHandler::RegisterModelMetadata() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("metadata"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ModelMetadataRequest, inference::ModelMetadataResponse>( @@ -870,9 +860,8 @@ CommonHandler::RegisterModelConfig() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("model-config"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ModelConfigRequest, inference::ModelConfigResponse>( @@ -1201,9 +1190,8 @@ CommonHandler::RegisterModelStatistics() #endif }; - const auto it = restricted_keys_.find("statistics"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::STATISTICS); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ModelStatisticsRequest, inference::ModelStatisticsResponse>( @@ -1477,9 +1465,8 @@ CommonHandler::RegisterTrace() #endif }; - const auto it = restricted_keys_.find("trace"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::TRACE); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::TraceSettingRequest, inference::TraceSettingResponse>( @@ -1687,9 +1674,8 @@ CommonHandler::RegisterLogging() #endif }; - const auto it = restricted_keys_.find("logging"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::LOGGING); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::LogSettingsRequest, inference::LogSettingsResponse>( @@ -1760,9 +1746,8 @@ CommonHandler::RegisterSystemSharedMemoryStatus() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryStatusResponse>, @@ -1799,9 +1784,8 @@ CommonHandler::RegisterSystemSharedMemoryRegister() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryRegisterResponse>, @@ -1843,9 +1827,8 @@ CommonHandler::RegisterSystemSharedMemoryUnregister() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryUnregisterResponse>, @@ -1911,9 +1894,8 @@ CommonHandler::RegisterCudaSharedMemoryStatus() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryStatusResponse>, @@ -1962,9 +1944,8 @@ CommonHandler::RegisterCudaSharedMemoryRegister() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryRegisterResponse>, @@ -2004,10 +1985,9 @@ CommonHandler::RegisterCudaSharedMemoryUnregister() GrpcStatusUtil::Create(status, err); TRITONSERVER_ErrorDelete(err); }; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryUnregisterResponse>, @@ -2111,9 +2091,8 @@ CommonHandler::RegisterRepositoryIndex() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("model-repository"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::RepositoryIndexRequest, inference::RepositoryIndexResponse>( @@ -2223,9 +2202,8 @@ CommonHandler::RegisterRepositoryModelLoad() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("model-repository"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::RepositoryModelLoadRequest, @@ -2292,9 +2270,8 @@ CommonHandler::RegisterRepositoryModelUnload() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("model-repository"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::RepositoryModelUnloadResponse>, @@ -2417,11 +2394,8 @@ Server::Server( // [FIXME] "register" logic is different for infer // Handler for model inference requests. - const auto it = options.restricted_protocols_.find("inference"); std::pair restricted_kv = - (it == options.restricted_protocols_.end()) - ? std::pair{"", ""} - : it->second; + options.restricted_protocols_.Get(RestrictedCategory::INFERENCE); for (int i = 0; i < REGISTER_GRPC_INFER_THREAD_COUNT; ++i) { model_infer_handlers_.emplace_back(new ModelInferHandler( "ModelInferHandler", tritonserver_, trace_manager_, shm_manager_, diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index 3aeae783e1..197cb72eea 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -29,6 +29,7 @@ #include +#include "../restricted_features.h" #include "../shared_memory_manager.h" #include "../tracer.h" #include "grpc_handler.h" @@ -84,7 +85,7 @@ struct Options { // requests doesn't exceed this value there will be no // allocation/deallocation of request/response objects. int infer_allocation_pool_size_{8}; - RestrictedFeatureMap restricted_protocols_; + RestrictedFeatures restricted_protocols_; std::string forward_header_pattern_; }; diff --git a/src/http_server.cc b/src/http_server.cc index 7fea66330d..18b5e3420b 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -80,14 +80,15 @@ namespace triton { namespace server { return; \ } while (false) -#define RETURN_AND_RESPOND_IF_RESTRICTED(REQ, API_KEY) \ - do { \ - static auto const is_restricted_api = restricted_apis_.count(API_KEY); \ - static auto const restricted_api = restricted_apis_.find(API_KEY); \ - if (is_restricted_api && \ - RespondIfRestricted(REQ, restricted_api->second)) { \ - return; \ - } \ +#define RETURN_AND_RESPOND_IF_RESTRICTED( \ + REQ, RESTRICTED_CATEGORY, RESTRICTED_APIS) \ + do { \ + auto const& is_restricted_api = \ + RESTRICTED_APIS.IsRestricted(RESTRICTED_CATEGORY); \ + auto const& restriction = RESTRICTED_APIS.Get(RESTRICTED_CATEGORY); \ + if (is_restricted_api && RespondIfRestricted(REQ, restriction)) { \ + return; \ + } \ } while (false) @@ -1050,7 +1051,7 @@ HTTPAPIServer::HTTPAPIServer( const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, - const RestrictedFeatureMap& restricted_apis) + const RestrictedFeatures& restricted_apis) : HTTPServer(port, reuse_port, address, header_forward_pattern, thread_cnt), server_(server), trace_manager_(trace_manager), shm_manager_(shm_manager), allocator_(nullptr), server_regex_(R"(/v2(?:/health/(live|ready))?)"), @@ -1283,7 +1284,8 @@ HTTPAPIServer::InferResponseFree( void HTTPAPIServer::HandleServerHealth(evhtp_request_t* req, const std::string& kind) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "health"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::HEALTH, restricted_apis_); if (req->method != htp_method_GET) { RETURN_AND_RESPOND_WITH_ERR( @@ -1309,7 +1311,8 @@ void HTTPAPIServer::HandleRepositoryIndex( evhtp_request_t* req, const std::string& repository_name) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "model-repository"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::MODEL_REPOSITORY, restricted_apis_); AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_POST) { @@ -1378,7 +1381,8 @@ HTTPAPIServer::HandleRepositoryControl( evhtp_request_t* req, const std::string& repository_name, const std::string& model_name, const std::string& action) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "model-repository"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::MODEL_REPOSITORY, restricted_apis_); AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_POST) { @@ -1533,7 +1537,8 @@ HTTPAPIServer::HandleModelReady( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "health"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::HEALTH, restricted_apis_); if (req->method != htp_method_GET) { RETURN_AND_RESPOND_WITH_ERR( @@ -1569,7 +1574,8 @@ HTTPAPIServer::HandleModelMetadata( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "metadata"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::METADATA, restricted_apis_); AddContentTypeHeader(req, "application/json"); @@ -1641,7 +1647,8 @@ HTTPAPIServer::HandleModelConfig( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "model-config"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::MODEL_CONFIG, restricted_apis_); AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_GET) { @@ -1668,7 +1675,8 @@ HTTPAPIServer::HandleModelStats( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "statistics"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::STATISTICS, restricted_apis_); AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_GET) { @@ -1712,7 +1720,8 @@ HTTPAPIServer::HandleModelStats( void HTTPAPIServer::HandleTrace(evhtp_request_t* req, const std::string& model_name) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "trace"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::TRACE, restricted_apis_); AddContentTypeHeader(req, "application/json"); if ((req->method != htp_method_GET) && (req->method != htp_method_POST)) { @@ -1967,7 +1976,8 @@ HTTPAPIServer::HandleTrace(evhtp_request_t* req, const std::string& model_name) void HTTPAPIServer::HandleLogging(evhtp_request_t* req) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "logging"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::LOGGING, restricted_apis_); AddContentTypeHeader(req, "application/json"); if ((req->method != htp_method_GET) && (req->method != htp_method_POST)) { @@ -2118,7 +2128,8 @@ HTTPAPIServer::HandleLogging(evhtp_request_t* req) void HTTPAPIServer::HandleServerMetadata(evhtp_request_t* req) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "metadata"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::METADATA, restricted_apis_); AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_GET) { @@ -2141,7 +2152,8 @@ HTTPAPIServer::HandleSystemSharedMemory( evhtp_request_t* req, const std::string& region_name, const std::string& action) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "shared-memory"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::SHARED_MEMORY, restricted_apis_); AddContentTypeHeader(req, "application/json"); if ((action == "status") && (req->method != htp_method_GET)) { @@ -2246,7 +2258,8 @@ HTTPAPIServer::HandleCudaSharedMemory( evhtp_request_t* req, const std::string& region_name, const std::string& action) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "shared-memory"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::SHARED_MEMORY, restricted_apis_); AddContentTypeHeader(req, "application/json"); if ((action == "status") && (req->method != htp_method_GET)) { @@ -3100,7 +3113,8 @@ HTTPAPIServer::HandleGenerate( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str, bool streaming) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "inference"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::INFERENCE, restricted_apis_); AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_POST) { @@ -3450,7 +3464,8 @@ HTTPAPIServer::HandleInfer( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { - RETURN_AND_RESPOND_IF_RESTRICTED(req, "inference"); + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::INFERENCE, restricted_apis_); if (req->method != htp_method_POST) { RETURN_AND_RESPOND_WITH_ERR( @@ -4529,7 +4544,7 @@ HTTPAPIServer::Create( const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, - const RestrictedFeatureMap& restricted_features, + const RestrictedFeatures& restricted_features, std::unique_ptr* http_server) { http_server->reset(new HTTPAPIServer( @@ -4544,10 +4559,10 @@ HTTPAPIServer::Create( bool HTTPAPIServer::RespondIfRestricted( - evhtp_request_t* req, const RestrictedFeature& restricted_api) + evhtp_request_t* req, const Restriction& restriction) { - auto header = restricted_api.first; - auto expected_value = restricted_api.second; + auto header = restriction.first; + auto expected_value = restriction.second; const char* actual_value = evhtp_kv_find(req->headers_in, header.c_str()); if ((actual_value == nullptr) || (actual_value != expected_value)) { EVBufferAddErrorJson( diff --git a/src/http_server.h b/src/http_server.h index 5a127777c7..53c5daea87 100644 --- a/src/http_server.h +++ b/src/http_server.h @@ -38,6 +38,7 @@ #include "common.h" #include "data_compressor.h" +#include "restricted_features.h" #include "shared_memory_manager.h" #include "tracer.h" #include "triton/common/logging.h" @@ -151,7 +152,7 @@ class HTTPAPIServer : public HTTPServer { const std::shared_ptr& smb_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, - const RestrictedFeatureMap& restricted_apis, + const RestrictedFeatures& restricted_apis, std::unique_ptr* http_server); virtual ~HTTPAPIServer(); @@ -360,7 +361,7 @@ class HTTPAPIServer : public HTTPServer { const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, - const RestrictedFeatureMap& restricted_apis); + const RestrictedFeatures& restricted_apis); virtual void Handle(evhtp_request_t* req) override; // [FIXME] extract to "infer" class virtual std::unique_ptr CreateInferRequest( @@ -545,9 +546,9 @@ class HTTPAPIServer : public HTTPServer { parameters_field, new MappingSchema(MappingSchema::Kind::MAPPING_SCHEMA, true)); } - const RestrictedFeatureMap& restricted_apis_{}; + const RestrictedFeatures& restricted_apis_{}; bool RespondIfRestricted( - evhtp_request_t* req, const RestrictedFeature& restricted_api); + evhtp_request_t* req, const Restriction& restriction); }; }} // namespace triton::server diff --git a/src/sagemaker_server.h b/src/sagemaker_server.h index 010561c750..7f1ef7e050 100644 --- a/src/sagemaker_server.h +++ b/src/sagemaker_server.h @@ -73,7 +73,7 @@ class SagemakerAPIServer : public HTTPAPIServer { : HTTPAPIServer( server, trace_manager, shm_manager, port, false /* reuse_port */, address, "" /* header_forward_pattern */, thread_cnt, - RestrictedFeatureMap()), + RestrictedFeatures()), ping_regex_(R"(/ping)"), invocations_regex_(R"(/invocations)"), models_regex_(R"(/models(?:/)?([^/]+)?(/invoke)?)"), model_path_regex_( diff --git a/src/vertex_ai_server.cc b/src/vertex_ai_server.cc index c9a990a24d..0cf4d14977 100644 --- a/src/vertex_ai_server.cc +++ b/src/vertex_ai_server.cc @@ -46,7 +46,7 @@ VertexAiAPIServer::VertexAiAPIServer( : HTTPAPIServer( server, trace_manager, shm_manager, port, false /* reuse_port */, address, "" /* header_forward_pattern */, thread_cnt, - RestrictedFeatureMap()), + RestrictedFeatures()), prediction_regex_(prediction_route), health_regex_(health_route), health_mode_("ready"), model_name_(default_model_name), model_version_str_("")