diff --git a/qa/L0_grpc/test.sh b/qa/L0_grpc/test.sh index cf43e324a2..73b9710a71 100755 --- a/qa/L0_grpc/test.sh +++ b/qa/L0_grpc/test.sh @@ -613,7 +613,24 @@ SERVER_ARGS="--model-repository=${MODELDIR} \ --grpc-restricted-protocol=model-repository,health:k1=v1 \ --grpc-restricted-protocol=metadata,health:k2=v2" run_server -EXPECTED_MSG="protocol 'health' can not be specified in multiple config group" +EXPECTED_MSG="protocol 'health' can not be specified in multiple config groups" +if [ "$SERVER_PID" != "0" ]; then + echo -e "\n***\n*** Expect fail to start $SERVER\n***" + kill $SERVER_PID + wait $SERVER_PID + RET=1 +elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then + echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***" + cat $SERVER_LOG + RET=1 +fi + +# Unknown protocol, not allowed +SERVER_ARGS="--model-repository=${MODELDIR} \ + --grpc-restricted-protocol=model-reposit,health:k1=v1 \ + --grpc-restricted-protocol=metadata,health:k2=v2" +run_server +EXPECTED_MSG="unknown restricted protocol 'model-reposit'" if [ "$SERVER_PID" != "0" ]; then echo -e "\n***\n*** Expect fail to start $SERVER\n***" kill $SERVER_PID diff --git a/qa/L0_http/http_restricted_api_test.py b/qa/L0_http/http_restricted_api_test.py new file mode 100755 index 0000000000..e5e3d5fd2d --- /dev/null +++ b/qa/L0_http/http_restricted_api_test.py @@ -0,0 +1,94 @@ +#!/usr/bin/python +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. 
+# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import sys + +sys.path.append("../common") + +import unittest + +import numpy as np +import tritonclient.http as tritonhttpclient +from tritonclient.utils import InferenceServerException + + +class RestrictedAPITest(unittest.TestCase): +    def setUp(self): +        self.model_name_ = "simple" +        self.client_ = tritonhttpclient.InferenceServerClient("localhost:8000") + +    # Other unspecified APIs should not be restricted +    def test_sanity(self): +        self.client_.get_inference_statistics("simple") +        self.client_.get_inference_statistics( +            "simple", headers={"infer-key": "infer-value"} +        ) + +    # metadata, infer, model repository APIs are restricted. +    # metadata and infer expect "infer-key : infer-value" header, +    # model repository expects "admin-key : admin-value". 
+    def test_model_repository(self): +        with self.assertRaisesRegex(InferenceServerException, "This API is restricted"): +            self.client_.unload_model( +                self.model_name_, headers={"infer-key": "infer-value"} +            ) +        # Request goes through and get actual transaction error +        with self.assertRaisesRegex( +            InferenceServerException, "explicit model load / unload is not allowed" +        ): +            self.client_.unload_model( +                self.model_name_, headers={"admin-key": "admin-value"} +            ) + +    def test_metadata(self): +        with self.assertRaisesRegex(InferenceServerException, "This API is restricted"): +            self.client_.get_server_metadata() +        self.client_.get_server_metadata({"infer-key": "infer-value"}) + +    def test_infer(self): +        # setup +        inputs = [ +            tritonhttpclient.InferInput("INPUT0", [1, 16], "INT32"), +            tritonhttpclient.InferInput("INPUT1", [1, 16], "INT32"), +        ] +        inputs[0].set_data_from_numpy(np.ones(shape=(1, 16), dtype=np.int32)) +        inputs[1].set_data_from_numpy(np.ones(shape=(1, 16), dtype=np.int32)) + +        # This test only cares if the request goes through +        with self.assertRaisesRegex(InferenceServerException, "This API is restricted"): +            _ = self.client_.infer( +                model_name=self.model_name_, inputs=inputs, headers={"test": "1"} +            ) +        self.client_.infer( +            model_name=self.model_name_, +            inputs=inputs, +            headers={"infer-key": "infer-value"}, +        ) + + +if __name__ == "__main__": +    unittest.main() diff --git a/qa/L0_http/test.sh b/qa/L0_http/test.sh index d224faeca4..c9ad809525 100755 --- a/qa/L0_http/test.sh +++ b/qa/L0_http/test.sh @@ -44,6 +44,7 @@ RET=0 CLIENT_PLUGIN_TEST="./http_client_plugin_test.py" BASIC_AUTH_TEST="./http_basic_auth_test.py" +RESTRICTED_API_TEST="./http_restricted_api_test.py" NGINX_CONF="./nginx.conf" # On windows the paths invoked by the script (running in WSL) must use # /mnt/c when needed but the paths on the tritonserver command-line @@ -588,13 +589,15 @@ kill $SERVER_PID wait $SERVER_PID # Run python unit test -rm -r ${MODELDIR}/* 
+MODELDIR=python_unit_test_models +mkdir -p $MODELDIR +rm -rf ${MODELDIR}/* cp -r $DATADIR/qa_identity_model_repository/onnx_zero_1_float32 ${MODELDIR}/. cp -r $DATADIR/qa_identity_model_repository/onnx_zero_1_object ${MODELDIR}/. cp -r $DATADIR/qa_identity_model_repository/onnx_zero_1_float16 ${MODELDIR}/. cp -r $DATADIR/qa_identity_model_repository/onnx_zero_3_float32 ${MODELDIR}/. cp -r ${MODELDIR}/onnx_zero_1_object ${MODELDIR}/onnx_zero_1_object_1_element && \ - (cd models/onnx_zero_1_object_1_element && \ + (cd $MODELDIR/onnx_zero_1_object_1_element && \ sed -i "s/onnx_zero_1_object/onnx_zero_1_object_1_element/" config.pbtxt && \ sed -i "0,/-1/{s/-1/1/}" config.pbtxt) @@ -667,6 +670,70 @@ set -e kill $SERVER_PID wait $SERVER_PID +### Test Restricted APIs ### +### Repeated API not allowed + +MODELDIR="`pwd`/models" +SERVER_ARGS="--model-repository=${MODELDIR} + --http-restricted-api=model-repository,health:k1=v1 \ + --http-restricted-api=metadata,health:k2=v2" +run_server +EXPECTED_MSG="api 'health' can not be specified in multiple config groups" +if [ "$SERVER_PID" != "0" ]; then + echo -e "\n***\n*** Expect fail to start $SERVER\n***" + kill $SERVER_PID + wait $SERVER_PID + RET=1 +elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then + echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***" + cat $SERVER_LOG + RET=1 +fi + +### Test Unknown Restricted API### +### Unknown API not allowed + +MODELDIR="`pwd`/models" +SERVER_ARGS="--model-repository=${MODELDIR} + --http-restricted-api=model-reposit,health:k1=v1 \ + --http-restricted-api=metadata,health:k2=v2" +run_server +EXPECTED_MSG="unknown restricted api 'model-reposit'" +if [ "$SERVER_PID" != "0" ]; then + echo -e "\n***\n*** Expect fail to start $SERVER\n***" + kill $SERVER_PID + wait $SERVER_PID + RET=1 +elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then + echo -e "\n***\n*** Failed. 
Expected ${EXPECTED_MSG} to be found in log\n***" + cat $SERVER_LOG + RET=1 +fi + +### Test Restricted APIs ### +### Restricted model-repository, metadata, and inference + +SERVER_ARGS="--model-repository=${MODELDIR} \ + --http-restricted-api=model-repository:admin-key=admin-value \ + --http-restricted-api=inference,metadata:infer-key=infer-value" +run_server +if [ "$SERVER_PID" == "0" ]; then + echo -e "\n***\n*** Failed to start $SERVER\n***" + cat $SERVER_LOG + exit 1 +fi +set +e + +python $RESTRICTED_API_TEST RestrictedAPITest > $CLIENT_LOG 2>&1 +if [ $? -ne 0 ]; then + cat $CLIENT_LOG + echo -e "\n***\n*** Python HTTP Restricted Protocol Test Failed\n***" + RET=1 +fi +set -e +kill $SERVER_PID +wait $SERVER_PID + ### if [ $RET -eq 0 ]; then diff --git a/src/command_line_parser.cc b/src/command_line_parser.cc index 5ec18841dc..0907982d6b 100644 --- a/src/command_line_parser.cc +++ b/src/command_line_parser.cc @@ -246,7 +246,6 @@ ParsePairOption(const std::string& arg, const std::string& delim_str) return {ParseOption(first_string), ParseOption(second_string)}; } -#ifdef TRITON_ENABLE_GRPC // Split 'options' by 'delim_str' and place split strings into a vector std::vector SplitOptions(std::string options, const std::string& delim_str) @@ -263,7 +262,6 @@ SplitOptions(std::string options, const std::string& delim_str) res.emplace_back(options); return res; } -#endif // TRITON_ENABLE_GRPC } // namespace @@ -290,6 +288,7 @@ enum TritonOptionId { OPTION_REUSE_HTTP_PORT, OPTION_HTTP_ADDRESS, OPTION_HTTP_THREAD_COUNT, + OPTION_HTTP_RESTRICTED_API, #endif // TRITON_ENABLE_HTTP #if defined(TRITON_ENABLE_GRPC) OPTION_ALLOW_GRPC, @@ -463,6 +462,16 @@ TritonParser::SetupOptions() http_options_.push_back( {OPTION_HTTP_THREAD_COUNT, "http-thread-count", Option::ArgInt, "Number of threads handling HTTP requests."}); + http_options_.push_back( + {OPTION_HTTP_RESTRICTED_API, "http-restricted-api", + ":=", + "Specify restricted HTTP api setting. 
The format of this " + "flag is --http-restricted-api=,=. Where " + " is a comma-separated list of apis to be restricted. " + " will be additional header key to be checked when a HTTP request " + "is received, and is the value expected to be matched." + " Allowed APIs: " + + Join(RESTRICTED_CATEGORY_NAMES, ", ")}); #endif // TRITON_ENABLE_HTTP #if defined(TRITON_ENABLE_GRPC) @@ -565,7 +574,9 @@ TritonParser::SetupOptions() "flag is --grpc-restricted-protocol=,=. Where " " is a comma-separated list of protocols to be restricted. " " will be additional header key to be checked when a GRPC request " - "is received, and is the value expected to be matched."}); + "is received, and is the value expected to be matched." + " Allowed protocols: " + + Join(RESTRICTED_CATEGORY_NAMES, ", ")}); #endif // TRITON_ENABLE_GRPC #ifdef TRITON_ENABLE_LOGGING @@ -1288,6 +1299,12 @@ TritonParser::Parse(int argc, char** argv) case OPTION_HTTP_THREAD_COUNT: lparams.http_thread_cnt_ = ParseOption(optarg); break; + case OPTION_HTTP_RESTRICTED_API: + ParseRestrictedFeatureOption( + optarg, long_options[option_index].name, "", "api", + lparams.http_restricted_apis_); + break; + #endif // TRITON_ENABLE_HTTP #ifdef TRITON_ENABLE_SAGEMAKER @@ -1368,7 +1385,8 @@ TritonParser::Parse(int argc, char** argv) lgrpc_options.infer_compression_level_ = GRPC_COMPRESS_LEVEL_HIGH; } else { throw ParseException( - "invalid argument for --grpc_infer_response_compression_level"); + "invalid argument for " + "--grpc_infer_response_compression_level"); } break; } @@ -1398,16 +1416,11 @@ TritonParser::Parse(int argc, char** argv) ParseOption(optarg); break; case OPTION_GRPC_RESTRICTED_PROTOCOL: { - const auto& parsed_tuple = ParseGrpcRestrictedProtocolOption(optarg); - const auto& protocols = SplitOptions(std::get<0>(parsed_tuple), ","); - const auto& key = std::get<1>(parsed_tuple); - const auto& value = std::get<2>(parsed_tuple); - grpc::ProtocolGroup pg; - for (const auto& p : protocols) { - 
pg.protocols_.emplace(p); - } - pg.restricted_key_ = std::make_pair(key, value); - lgrpc_options.protocol_groups_.emplace_back(pg); + ParseRestrictedFeatureOption( + optarg, long_options[option_index].name, + std::string( + triton::server::grpc::kRestrictedProtocolHeaderTemplate), + "protocol", lgrpc_options.restricted_protocols_); break; } case OPTION_GRPC_HEADER_FORWARD_PATTERN: @@ -1759,7 +1772,8 @@ TritonParser::ParseMetricsConfigOption(const std::string& arg) int delim_name = name_substr.find(","); // No name-specific configs currently supported, though it may be in - // the future. Map global configs to empty string like other configs for now. + // the future. Map global configs to empty string like other configs for + // now. std::string name_string = std::string(); if (delim_name >= 0) { std::stringstream ss; @@ -1830,7 +1844,8 @@ TritonParser::ParseRateLimiterResourceOption(const std::string& arg) { std::string error_string( "--rate-limit-resource option format is " - "'::' or ':'. Got " + + "'::' or ':'. " + "Got " + arg); std::string name_string(""); @@ -1904,57 +1919,64 @@ TritonParser::ParseBackendConfigOption(const std::string& arg) return {name_string, setting_string, value_string}; } -std::tuple -TritonParser::ParseGrpcRestrictedProtocolOption(const std::string& arg) +void +TritonParser::ParseRestrictedFeatureOption( + const std::string& arg, const std::string& option_name, + const std::string& key_prefix, const std::string& feature_type, + RestrictedFeatures& restricted_features) { - try { - return ParseGenericConfigOption(arg, ":", "="); - } - catch (const ParseException& pe) { - // catch and throw exception with option specific message - std::stringstream ss; - ss << "--grpc-restricted-protocol option format is ':='. 
Got " - << arg << std::endl; - throw ParseException(ss.str()); + const auto& parsed_tuple = + ParseGenericConfigOption(arg, ":", "=", option_name, "config name"); + + const auto& features = SplitOptions(std::get<0>(parsed_tuple), ","); + const auto& key = std::get<1>(parsed_tuple); + const auto& value = std::get<2>(parsed_tuple); + + for (const auto& feature : features) { + const auto& category = RestrictedFeatures::ToCategory(feature); + + if (category == RestrictedCategory::INVALID) { + std::stringstream ss; + ss << "unknown restricted " << feature_type << " '" << feature << "' " + << std::endl; + throw ParseException(ss.str()); + } + + if (restricted_features.IsRestricted(category)) { + // restricted feature can only be in one group + std::stringstream ss; + ss << "restricted " << feature_type << " '" << feature + << "' can not be specified in multiple config groups" << std::endl; + throw ParseException(ss.str()); + } + restricted_features.Insert( + category, std::make_pair(key_prefix + key, value)); } - // Should not reach here - return {}; } std::tuple TritonParser::ParseHostPolicyOption(const std::string& arg) { - try { - return ParseGenericConfigOption(arg, ",", "="); - } - catch (const ParseException& pe) { - // catch and throw exception with option specific message - std::stringstream ss; - ss << "--host-policy option format is ',='. 
Got " - << arg << std::endl; - throw ParseException(ss.str()); - } - // Should not reach here - return {}; + return ParseGenericConfigOption(arg, ",", "=", "host-policy", "policy name"); } std::tuple TritonParser::ParseGenericConfigOption( const std::string& arg, const std::string& first_delim, - const std::string& second_delim) + const std::string& second_delim, const std::string& option_name, + const std::string& config_name) { // Format is ",=" int delim_name = arg.find(first_delim); int delim_setting = arg.find(second_delim, delim_name + 1); + std::string error_string = "--" + option_name + " option format is '<" + + config_name + ">" + first_delim + "" + + second_delim + "'. Got " + arg + "\n"; + // Check for 2 semicolons if ((delim_name < 0) || (delim_setting < 0)) { - std::stringstream ss; - ss << "option format is '" << first_delim << "" - << second_delim << "'. Got " << arg << std::endl; - throw ParseException(ss.str()); + throw ParseException(error_string); } std::string name_string = arg.substr(0, delim_name); @@ -1963,10 +1985,7 @@ TritonParser::ParseGenericConfigOption( std::string value_string = arg.substr(delim_setting + 1); if (name_string.empty() || setting_string.empty() || value_string.empty()) { - std::stringstream ss; - ss << "option format is '" << first_delim << "" - << second_delim << "'. 
Got " << arg << std::endl; - throw ParseException(ss.str()); + throw ParseException(error_string); } return {name_string, setting_string, value_string}; @@ -2177,5 +2196,4 @@ TritonParser::PostProcessTraceArgs( } #endif // TRITON_ENABLE_TRACING - -}} // namespace triton::server +}} // namespace triton::server diff --git a/src/command_line_parser.h b/src/command_line_parser.h index efffcff94e..8a34babc98 100644 --- a/src/command_line_parser.h +++ b/src/command_line_parser.h @@ -35,6 +35,7 @@ #include #include +#include "restricted_features.h" #include "triton/common/logging.h" #include "triton/core/tritonserver.h" #ifdef TRITON_ENABLE_GRPC @@ -189,6 +190,7 @@ struct TritonServerParameters { std::string http_forward_header_pattern_; // The number of threads to initialize for the HTTP front-end. int http_thread_cnt_{8}; + RestrictedFeatures http_restricted_apis_{}; #endif // TRITON_ENABLE_HTTP #ifdef TRITON_ENABLE_GRPC @@ -281,8 +283,10 @@ class TritonParser { const std::string& arg); std::tuple ParseMetricsConfigOption( const std::string& arg); - std::tuple - ParseGrpcRestrictedProtocolOption(const std::string& arg); + void ParseRestrictedFeatureOption( + const std::string& arg, const std::string& option_name, + const std::string& header_prefix, const std::string& feature_type, + RestrictedFeatures& restricted_features); #ifdef TRITON_ENABLE_TRACING TRITONSERVER_InferenceTraceLevel ParseTraceLevelOption(std::string arg); InferenceTraceMode ParseTraceModeOption(std::string arg); @@ -308,7 +312,8 @@ class TritonParser { // "[1st_delim][2nd_delim]" format std::tuple ParseGenericConfigOption( const std::string& arg, const std::string& first_delim, - const std::string& second_delim); + const std::string& second_delim, const std::string& option_name, + const std::string& config_name); // Initialize individual option groups void SetupOptions(); diff --git a/src/common.cc b/src/common.cc index 83fe3c6c25..289d868866 100644 --- a/src/common.cc +++ b/src/common.cc @@ -29,6 
+29,7 @@ #include #include +#include "restricted_features.h" #include "triton/core/tritonserver.h" namespace triton { namespace server { @@ -101,14 +102,4 @@ Contains(const std::vector& vec, const std::string& str) return std::find(vec.begin(), vec.end(), str) != vec.end(); } -std::string -Join(const std::vector& vec, const std::string& delim) -{ - std::stringstream ss; - std::copy( - vec.begin(), vec.end(), - std::ostream_iterator(ss, delim.c_str())); - return ss.str(); -} - }} // namespace triton::server diff --git a/src/common.h b/src/common.h index c11254a6cc..aa160f394f 100644 --- a/src/common.h +++ b/src/common.h @@ -163,11 +163,25 @@ int64_t GetElementCount(const std::vector& dims); /// \return True if the str is found, false otherwise. bool Contains(const std::vector& vec, const std::string& str); -/// Joins vector of strings 'vec' into a single string delimited by 'delim'. +/// Joins container of strings into a single string delimited by +/// 'delim'. /// -/// \param vec The vector of strings to join. +/// \param container The container of strings to join. /// \param delim The delimiter to join with. /// \return The joint string. 
-std::string Join(const std::vector& vec, const std::string& delim); +template +std::string +Join(const T& container, const std::string& delim) +{ + if (container.empty()) { + return ""; + } + std::stringstream ss; + ss << container[0]; + for (size_t i = 1; i < container.size(); ++i) { + ss << delim << container[i]; + } + return ss.str(); +} }} // namespace triton::server diff --git a/src/grpc/grpc_server.cc b/src/grpc/grpc_server.cc index dbc0f85559..f9dbd5c016 100644 --- a/src/grpc/grpc_server.cc +++ b/src/grpc/grpc_server.cc @@ -253,8 +253,7 @@ class CommonHandler : public HandlerBase { inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, ::grpc::ServerCompletionQueue* cq, - std::map> - restricted_keys); + const RestrictedFeatures& restricted_keys); // Descriptive name of of the handler. const std::string& Name() const { return name_; } @@ -299,13 +298,9 @@ class CommonHandler : public HandlerBase { ::grpc::health::v1::Health::AsyncService* health_service_; ::grpc::ServerCompletionQueue* cq_; std::unique_ptr thread_; - std::map> restricted_keys_; - static std::pair empty_restricted_key_; + const RestrictedFeatures& restricted_keys_; }; -std::pair CommonHandler::empty_restricted_key_{ - "", ""}; - CommonHandler::CommonHandler( const std::string& name, const std::shared_ptr& tritonserver, @@ -314,7 +309,7 @@ CommonHandler::CommonHandler( inference::GRPCInferenceService::AsyncService* service, ::grpc::health::v1::Health::AsyncService* health_service, ::grpc::ServerCompletionQueue* cq, - std::map> restricted_keys) + const RestrictedFeatures& restricted_keys) : name_(name), tritonserver_(tritonserver), shm_manager_(shm_manager), trace_manager_(trace_manager), service_(service), health_service_(health_service), cq_(cq), @@ -439,9 +434,8 @@ CommonHandler::RegisterServerLive() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("health"); - std::pair restricted_kv = - (it == 
restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ServerLiveRequest, inference::ServerLiveResponse>( @@ -476,9 +470,8 @@ CommonHandler::RegisterServerReady() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("health"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ServerReadyRequest, inference::ServerReadyResponse>( @@ -524,9 +517,8 @@ CommonHandler::RegisterHealthCheck() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("health"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); new CommonCallData< ::grpc::ServerAsyncResponseWriter< ::grpc::health::v1::HealthCheckResponse>, @@ -569,9 +561,8 @@ CommonHandler::RegisterModelReady() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("health"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::HEALTH); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ModelReadyRequest, inference::ModelReadyResponse>( @@ -648,9 +639,8 @@ CommonHandler::RegisterServerMetadata() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("metadata"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? 
empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ServerMetadataRequest, inference::ServerMetadataResponse>( @@ -817,9 +807,8 @@ CommonHandler::RegisterModelMetadata() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("metadata"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::METADATA); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ModelMetadataRequest, inference::ModelMetadataResponse>( @@ -871,9 +860,8 @@ CommonHandler::RegisterModelConfig() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("model-config"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_CONFIG); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ModelConfigRequest, inference::ModelConfigResponse>( @@ -1202,9 +1190,8 @@ CommonHandler::RegisterModelStatistics() #endif }; - const auto it = restricted_keys_.find("statistics"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::STATISTICS); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::ModelStatisticsRequest, inference::ModelStatisticsResponse>( @@ -1478,9 +1465,8 @@ CommonHandler::RegisterTrace() #endif }; - const auto it = restricted_keys_.find("trace"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? 
empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::TRACE); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::TraceSettingRequest, inference::TraceSettingResponse>( @@ -1688,9 +1674,8 @@ CommonHandler::RegisterLogging() #endif }; - const auto it = restricted_keys_.find("logging"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::LOGGING); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::LogSettingsRequest, inference::LogSettingsResponse>( @@ -1761,9 +1746,8 @@ CommonHandler::RegisterSystemSharedMemoryStatus() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryStatusResponse>, @@ -1800,9 +1784,8 @@ CommonHandler::RegisterSystemSharedMemoryRegister() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryRegisterResponse>, @@ -1844,9 +1827,8 @@ CommonHandler::RegisterSystemSharedMemoryUnregister() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? 
empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::SystemSharedMemoryUnregisterResponse>, @@ -1912,9 +1894,8 @@ CommonHandler::RegisterCudaSharedMemoryStatus() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryStatusResponse>, @@ -1963,9 +1944,8 @@ CommonHandler::RegisterCudaSharedMemoryRegister() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryRegisterResponse>, @@ -2005,10 +1985,9 @@ CommonHandler::RegisterCudaSharedMemoryUnregister() GrpcStatusUtil::Create(status, err); TRITONSERVER_ErrorDelete(err); }; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::SHARED_MEMORY); - const auto it = restricted_keys_.find("shared-memory"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::CudaSharedMemoryUnregisterResponse>, @@ -2112,9 +2091,8 @@ CommonHandler::RegisterRepositoryIndex() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("model-repository"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? 
empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::RepositoryIndexRequest, inference::RepositoryIndexResponse>( @@ -2224,9 +2202,8 @@ CommonHandler::RegisterRepositoryModelLoad() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("model-repository"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter, inference::RepositoryModelLoadRequest, @@ -2293,9 +2270,8 @@ CommonHandler::RegisterRepositoryModelUnload() TRITONSERVER_ErrorDelete(err); }; - const auto it = restricted_keys_.find("model-repository"); - std::pair restricted_kv = - (it == restricted_keys_.end()) ? empty_restricted_key_ : it->second; + const std::pair& restricted_kv = + restricted_keys_.Get(RestrictedCategory::MODEL_REPOSITORY); new CommonCallData< ::grpc::ServerAsyncResponseWriter< inference::RepositoryModelUnloadResponse>, @@ -2411,35 +2387,15 @@ Server::Server( model_infer_cq_ = builder_.AddCompletionQueue(); model_stream_infer_cq_ = builder_.AddCompletionQueue(); - // Read and set restriction for each protocol specified - // map from protocol name to a pair of header to look for and the key - std::map> restricted_keys; - for (const auto& pg : options.protocol_groups_) { - for (const auto& p : pg.protocols_) { - if (restricted_keys.find(p) != restricted_keys.end()) { - throw std::invalid_argument( - std::string("protocol '") + p + - "' can not be " - "specified in multiple config group"); - } - const auto header = std::string(kRestrictedProtocolHeaderTemplate) + - pg.restricted_key_.first; - restricted_keys[p] = std::make_pair(header, pg.restricted_key_.second); - } - } - // A common Handler for other non-inference requests 
common_handler_.reset(new CommonHandler( "CommonHandler", tritonserver_, shm_manager_, trace_manager_, &service_, - &health_service_, common_cq_.get(), restricted_keys)); + &health_service_, common_cq_.get(), options.restricted_protocols_)); // [FIXME] "register" logic is different for infer // Handler for model inference requests. - const auto it = restricted_keys.find("inference"); std::pair restricted_kv = - (it == restricted_keys.end()) - ? std::pair{"", ""} - : it->second; + options.restricted_protocols_.Get(RestrictedCategory::INFERENCE); for (int i = 0; i < REGISTER_GRPC_INFER_THREAD_COUNT; ++i) { model_infer_handlers_.emplace_back(new ModelInferHandler( "ModelInferHandler", tritonserver_, trace_manager_, shm_manager_, diff --git a/src/grpc/grpc_server.h b/src/grpc/grpc_server.h index 4bbb54594f..197cb72eea 100644 --- a/src/grpc/grpc_server.h +++ b/src/grpc/grpc_server.h @@ -29,6 +29,7 @@ #include +#include "../restricted_features.h" #include "../shared_memory_manager.h" #include "../tracer.h" #include "grpc_handler.h" @@ -74,12 +75,6 @@ struct KeepAliveOptions { int http2_max_ping_strikes_{2}; }; -struct ProtocolGroup { - std::string name_{""}; - std::set protocols_{}; - std::pair restricted_key_{"", ""}; -}; - struct Options { SocketOptions socket_; SslOptions ssl_; @@ -90,7 +85,7 @@ struct Options { // requests doesn't exceed this value there will be no // allocation/deallocation of request/response objects. 
int infer_allocation_pool_size_{8}; - std::vector protocol_groups_{}; + RestrictedFeatures restricted_protocols_; std::string forward_header_pattern_; }; diff --git a/src/http_server.cc b/src/http_server.cc index 83051bb4a1..7f74d1c373 100644 --- a/src/http_server.cc +++ b/src/http_server.cc @@ -80,6 +80,18 @@ namespace triton { namespace server { return; \ } while (false) +#define RETURN_AND_RESPOND_IF_RESTRICTED( \ + REQ, RESTRICTED_CATEGORY, RESTRICTED_APIS) \ + do { \ + auto const& is_restricted_api = \ + RESTRICTED_APIS.IsRestricted(RESTRICTED_CATEGORY); \ + auto const& restriction = RESTRICTED_APIS.Get(RESTRICTED_CATEGORY); \ + if (is_restricted_api && RespondIfRestricted(REQ, restriction)) { \ + return; \ + } \ + } while (false) + + namespace { void EVBufferAddErrorJson(evbuffer* buffer, const char* message) @@ -1038,7 +1050,8 @@ HTTPAPIServer::HTTPAPIServer( triton::server::TraceManager* trace_manager, const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, - const std::string& header_forward_pattern, const int thread_cnt) + const std::string& header_forward_pattern, const int thread_cnt, + const RestrictedFeatures& restricted_apis) : HTTPServer(port, reuse_port, address, header_forward_pattern, thread_cnt), server_(server), trace_manager_(trace_manager), shm_manager_(shm_manager), allocator_(nullptr), server_regex_(R"(/v2(?:/health/(live|ready))?)"), @@ -1050,7 +1063,7 @@ HTTPAPIServer::HTTPAPIServer( R"(/v2/systemsharedmemory(?:/region/([^/]+))?/(status|register|unregister))"), cudasharedmemory_regex_( R"(/v2/cudasharedmemory(?:/region/([^/]+))?/(status|register|unregister))"), - trace_regex_(R"(/v2/trace/setting)") + trace_regex_(R"(/v2/trace/setting)"), restricted_apis_(restricted_apis) { // FIXME, don't cache server metadata. 
The http endpoint should // not be deciding that server metadata will not change during @@ -1271,6 +1284,9 @@ HTTPAPIServer::InferResponseFree( void HTTPAPIServer::HandleServerHealth(evhtp_request_t* req, const std::string& kind) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::HEALTH, restricted_apis_); + if (req->method != htp_method_GET) { RETURN_AND_RESPOND_WITH_ERR( req, EVHTP_RES_METHNALLOWED, "Method Not Allowed"); @@ -1295,6 +1311,9 @@ void HTTPAPIServer::HandleRepositoryIndex( evhtp_request_t* req, const std::string& repository_name) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::MODEL_REPOSITORY, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_POST) { RETURN_AND_RESPOND_WITH_ERR( @@ -1362,6 +1381,9 @@ HTTPAPIServer::HandleRepositoryControl( evhtp_request_t* req, const std::string& repository_name, const std::string& model_name, const std::string& action) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::MODEL_REPOSITORY, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_POST) { RETURN_AND_RESPOND_WITH_ERR( @@ -1515,6 +1537,9 @@ HTTPAPIServer::HandleModelReady( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::HEALTH, restricted_apis_); + if (req->method != htp_method_GET) { RETURN_AND_RESPOND_WITH_ERR( req, EVHTP_RES_METHNALLOWED, "Method Not Allowed"); @@ -1549,7 +1574,11 @@ HTTPAPIServer::HandleModelMetadata( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::METADATA, restricted_apis_); + AddContentTypeHeader(req, "application/json"); + if (req->method != htp_method_GET) { RETURN_AND_RESPOND_WITH_ERR( req, EVHTP_RES_METHNALLOWED, "Method Not Allowed"); @@ -1618,6 +1647,9 @@ 
HTTPAPIServer::HandleModelConfig( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::MODEL_CONFIG, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_GET) { RETURN_AND_RESPOND_WITH_ERR( @@ -1643,6 +1675,9 @@ HTTPAPIServer::HandleModelStats( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::STATISTICS, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_GET) { RETURN_AND_RESPOND_WITH_ERR( @@ -1685,6 +1720,9 @@ HTTPAPIServer::HandleModelStats( void HTTPAPIServer::HandleTrace(evhtp_request_t* req, const std::string& model_name) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::TRACE, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if ((req->method != htp_method_GET) && (req->method != htp_method_POST)) { RETURN_AND_RESPOND_WITH_ERR( @@ -1938,6 +1976,9 @@ HTTPAPIServer::HandleTrace(evhtp_request_t* req, const std::string& model_name) void HTTPAPIServer::HandleLogging(evhtp_request_t* req) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::LOGGING, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if ((req->method != htp_method_GET) && (req->method != htp_method_POST)) { RETURN_AND_RESPOND_WITH_ERR( @@ -2087,6 +2128,9 @@ HTTPAPIServer::HandleLogging(evhtp_request_t* req) void HTTPAPIServer::HandleServerMetadata(evhtp_request_t* req) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::METADATA, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_GET) { RETURN_AND_RESPOND_WITH_ERR( @@ -2108,6 +2152,9 @@ HTTPAPIServer::HandleSystemSharedMemory( evhtp_request_t* req, const std::string& region_name, const std::string& action) { + 
RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::SHARED_MEMORY, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if ((action == "status") && (req->method != htp_method_GET)) { RETURN_AND_RESPOND_WITH_ERR( @@ -2211,6 +2258,9 @@ HTTPAPIServer::HandleCudaSharedMemory( evhtp_request_t* req, const std::string& region_name, const std::string& action) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::SHARED_MEMORY, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if ((action == "status") && (req->method != htp_method_GET)) { RETURN_AND_RESPOND_WITH_ERR( @@ -3063,6 +3113,9 @@ HTTPAPIServer::HandleGenerate( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str, bool streaming) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::INFERENCE, restricted_apis_); + AddContentTypeHeader(req, "application/json"); if (req->method != htp_method_POST) { RETURN_AND_RESPOND_WITH_ERR( @@ -3411,6 +3464,9 @@ HTTPAPIServer::HandleInfer( evhtp_request_t* req, const std::string& model_name, const std::string& model_version_str) { + RETURN_AND_RESPOND_IF_RESTRICTED( + req, RestrictedCategory::INFERENCE, restricted_apis_); + if (req->method != htp_method_POST) { RETURN_AND_RESPOND_WITH_ERR( req, EVHTP_RES_METHNALLOWED, "Method Not Allowed"); @@ -4487,11 +4543,12 @@ HTTPAPIServer::Create( const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, + const RestrictedFeatures& restricted_features, std::unique_ptr* http_server) { http_server->reset(new HTTPAPIServer( server, trace_manager, shm_manager, port, reuse_port, address, - header_forward_pattern, thread_cnt)); + header_forward_pattern, thread_cnt, restricted_features)); const std::string addr = address + ":" + std::to_string(port); LOG_INFO << "Started HTTPService at " << addr; @@ -4499,4 +4556,22 @@ 
HTTPAPIServer::Create( return nullptr; } +bool +HTTPAPIServer::RespondIfRestricted( + evhtp_request_t* req, const Restriction& restriction) +{ + auto header = restriction.first; + auto expected_value = restriction.second; + const char* actual_value = evhtp_kv_find(req->headers_in, header.c_str()); + if ((actual_value == nullptr) || (actual_value != expected_value)) { + EVBufferAddErrorJson( + req->buffer_out, + std::string("This API is restricted, expecting header '" + header + "'") + .c_str()); + evhtp_send_reply(req, EVHTP_RES_FORBIDDEN); + return true; + } + return false; +} + }} // namespace triton::server diff --git a/src/http_server.h b/src/http_server.h index 1b17440674..53c5daea87 100644 --- a/src/http_server.h +++ b/src/http_server.h @@ -38,6 +38,7 @@ #include "common.h" #include "data_compressor.h" +#include "restricted_features.h" #include "shared_memory_manager.h" #include "tracer.h" #include "triton/common/logging.h" @@ -151,6 +152,7 @@ class HTTPAPIServer : public HTTPServer { const std::shared_ptr& smb_manager, const int32_t port, const bool reuse_port, const std::string& address, const std::string& header_forward_pattern, const int thread_cnt, + const RestrictedFeatures& restricted_apis, std::unique_ptr* http_server); virtual ~HTTPAPIServer(); @@ -358,7 +360,8 @@ class HTTPAPIServer : public HTTPServer { triton::server::TraceManager* trace_manager, const std::shared_ptr& shm_manager, const int32_t port, const bool reuse_port, const std::string& address, - const std::string& header_forward_pattern, const int thread_cnt); + const std::string& header_forward_pattern, const int thread_cnt, + const RestrictedFeatures& restricted_apis); virtual void Handle(evhtp_request_t* req) override; // [FIXME] extract to "infer" class virtual std::unique_ptr CreateInferRequest( @@ -543,6 +546,9 @@ class HTTPAPIServer : public HTTPServer { parameters_field, new MappingSchema(MappingSchema::Kind::MAPPING_SCHEMA, true)); } + const RestrictedFeatures& restricted_apis_{}; 
+ bool RespondIfRestricted( + evhtp_request_t* req, const Restriction& restriction); }; }} // namespace triton::server diff --git a/src/main.cc b/src/main.cc index bda1d66764..14fde049c3 100644 --- a/src/main.cc +++ b/src/main.cc @@ -138,7 +138,8 @@ StartHttpService( server, trace_manager, shm_manager, g_triton_params.http_port_, g_triton_params.reuse_http_port_, g_triton_params.http_address_, g_triton_params.http_forward_header_pattern_, - g_triton_params.http_thread_cnt_, service); + g_triton_params.http_thread_cnt_, g_triton_params.http_restricted_apis_, + service); if (err == nullptr) { err = (*service)->Start(); } diff --git a/src/restricted_features.h b/src/restricted_features.h new file mode 100644 index 0000000000..1b366e8ec4 --- /dev/null +++ b/src/restricted_features.h @@ -0,0 +1,114 @@ +// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions +// are met: +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// * Neither the name of NVIDIA CORPORATION nor the names of its +// contributors may be used to endorse or promote products derived +// from this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#pragma once + +#include +#include +#include + +namespace triton { namespace server { + +/// Header and Value pair for a restricted feature +using Restriction = std::pair; + +/// Restricted Categories +enum RestrictedCategory : uint8_t { + HEALTH, + METADATA, + INFERENCE, + SHARED_MEMORY, + MODEL_CONFIG, + MODEL_REPOSITORY, + STATISTICS, + TRACE, + LOGGING, + INVALID, + CATEGORY_COUNT = INVALID +}; + +/// Restricted Category Names +const std::array + RESTRICTED_CATEGORY_NAMES{ + "health", "metadata", "inference", + "shared-memory", "model-config", "model-repository", + "statistics", "trace", "logging"}; + +/// Collection of restricted features +/// +/// Initially empty and all categories unrestricted +class RestrictedFeatures { + public: + /// Returns RestrictedCategory enum from category name + /// + /// \param[in] category category name + /// \return category enum; INVALID if unknown + static RestrictedCategory ToCategory(const std::string& category) + { + const auto found = std::find( + begin(RESTRICTED_CATEGORY_NAMES), end(RESTRICTED_CATEGORY_NAMES), + category); + const auto offset = std::distance(begin(RESTRICTED_CATEGORY_NAMES), found); + return RestrictedCategory(offset); + } + + /// Insert restriction for given category + /// + /// \param[in] category category to restrict + /// \param[in] restriction header, value pair + void Insert(const RestrictedCategory& category, Restriction&& restriction) + { + 
restrictions_[category] = std::move(restriction); + restricted_categories_[category] = true; + } + + /// Get header,value pair for restricted category + /// + /// \param[in] category category to look up + /// \return restriction header, value pair + const Restriction& Get(RestrictedCategory category) const + { + return restrictions_[category]; + } + + /// Return true if a category is restricted + /// + /// \param[in] category category to check + /// \return true if category is restricted, false otherwise + + const bool& IsRestricted(RestrictedCategory category) const + { + return restricted_categories_[category]; + } + + RestrictedFeatures() = default; + ~RestrictedFeatures() = default; + + private: + std::array restrictions_{}; + + std::array restricted_categories_{}; +}; +}} // namespace triton::server diff --git a/src/sagemaker_server.h b/src/sagemaker_server.h index 2b8189be86..7f1ef7e050 100644 --- a/src/sagemaker_server.h +++ b/src/sagemaker_server.h @@ -72,7 +72,8 @@ class SagemakerAPIServer : public HTTPAPIServer { const int32_t port, const std::string address, const int thread_cnt) : HTTPAPIServer( server, trace_manager, shm_manager, port, false /* reuse_port */, - address, "" /* header_forward_pattern */, thread_cnt), + address, "" /* header_forward_pattern */, thread_cnt, + RestrictedFeatures()), ping_regex_(R"(/ping)"), invocations_regex_(R"(/invocations)"), models_regex_(R"(/models(?:/)?([^/]+)?(/invoke)?)"), model_path_regex_( diff --git a/src/vertex_ai_server.cc b/src/vertex_ai_server.cc index f14143f5b6..626a210561 100644 --- a/src/vertex_ai_server.cc +++ b/src/vertex_ai_server.cc @@ -1,4 +1,4 @@ -// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions @@ -45,7 +45,8 @@ VertexAiAPIServer::VertexAiAPIServer( const std::string& default_model_name) : HTTPAPIServer( server, trace_manager, shm_manager, port, false /* reuse_port */, - address, "" /* header_forward_pattern */, thread_cnt), + address, "" /* header_forward_pattern */, thread_cnt, + RestrictedFeatures()), prediction_regex_(prediction_route), health_regex_(health_route), health_mode_("ready"), model_name_(default_model_name), model_version_str_("")