Skip to content

Commit

Permalink
updated with restricted features class and additional tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nnshah1 committed Oct 31, 2023
1 parent a276271 commit a64ce5d
Show file tree
Hide file tree
Showing 12 changed files with 158 additions and 114 deletions.
17 changes: 17 additions & 0 deletions qa/L0_grpc/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,23 @@ elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then
RET=1
fi

# Unknown protocol, not allowed
SERVER_ARGS="--model-repository=${MODELDIR} \
--grpc-restricted-protocol=model-reposit,health:k1=v1 \
--grpc-restricted-protocol=metadata,health:k2=v2"
run_server
EXPECTED_MSG="unknown restricted protocol 'model-reposit'"
if [ "$SERVER_PID" != "0" ]; then
echo -e "\n***\n*** Expect fail to start $SERVER\n***"
kill $SERVER_PID
wait $SERVER_PID
RET=1
elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then
echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***"
cat $SERVER_LOG
RET=1
fi

# Test restricted protocols
SERVER_ARGS="--model-repository=${MODELDIR} \
--grpc-restricted-protocol=model-repository:admin-key=admin-value \
Expand Down
9 changes: 7 additions & 2 deletions qa/L0_http/http_restricted_api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ def test_sanity(self):
"simple", headers={"infer-key": "infer-value"}
)

# health, infer, model repository APIs are restricted.
# health and infer expects "infer-key : infer-value" header,
# metadata, infer, model repository APIs are restricted.
# metadata and infer expects "infer-key : infer-value" header,
# model repository expected "admin-key : admin-value".
def test_model_repository(self):
with self.assertRaisesRegex(InferenceServerException, "This API is restricted"):
Expand All @@ -64,6 +64,11 @@ def test_model_repository(self):
self.model_name_, headers={"admin-key": "admin-value"}
)

def test_metadata(self):
with self.assertRaisesRegex(InferenceServerException, "This API is restricted"):
self.client_.get_server_metadata()
self.client_.get_server_metadata({"infer-key": "infer-value"})

def test_infer(self):
# setup
inputs = [
Expand Down
25 changes: 24 additions & 1 deletion qa/L0_http/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,8 @@ kill $SERVER_PID
wait $SERVER_PID

### Test Restricted APIs ###
### Repeated API not allowed

MODELDIR="`pwd`/models"
SERVER_ARGS="--model-repository=${MODELDIR}
--http-restricted-api=model-repository,health:k1=v1 \
Expand All @@ -688,11 +690,32 @@ elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then
RET=1
fi

### Test Unknown Restricted API###
### Unknown API not allowed

MODELDIR="`pwd`/models"
SERVER_ARGS="--model-repository=${MODELDIR}
--http-restricted-api=model-reposit,health:k1=v1 \
--http-restricted-api=metadata,health:k2=v2"
run_server
EXPECTED_MSG="unknown restricted api 'model-reposit'"
if [ "$SERVER_PID" != "0" ]; then
echo -e "\n***\n*** Expect fail to start $SERVER\n***"
kill $SERVER_PID
wait $SERVER_PID
RET=1
elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then
echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***"
cat $SERVER_LOG
RET=1
fi

### Test Restricted APIs ###
### Restricted model-repository, metadata, and inference

SERVER_ARGS="--model-repository=${MODELDIR} \
--http-restricted-api=model-repository:admin-key=admin-value \
--http-restricted-api=inference:infer-key=infer-value"
--http-restricted-api=inference,metadata:infer-key=infer-value"
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
Expand Down
16 changes: 13 additions & 3 deletions src/command_line_parser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@ void
TritonParser::ParseRestrictedFeatureOption(
const std::string& arg, const std::string& option_name,
const std::string& key_prefix, const std::string& feature_type,
RestrictedFeatureMap& restricted_features)
RestrictedFeatures& restricted_features)
{
const auto& parsed_tuple =
ParseGenericConfigOption(arg, ":", "=", option_name, "config name");
Expand All @@ -1929,14 +1929,24 @@ TritonParser::ParseRestrictedFeatureOption(
const auto& value = std::get<2>(parsed_tuple);

for (const auto& feature : features) {
if (restricted_features.count(feature)) {
const auto& category = RestrictedFeatures::ToCategory(feature);

if (category == RestrictedCategory::INVALID) {
std::stringstream ss;
ss << "unknown restricted " << feature_type << " '" << feature << "' "
<< std::endl;
throw ParseException(ss.str());
}

if (restricted_features.IsRestricted(category)) {
// restricted feature can only be in one group
std::stringstream ss;
ss << "restricted " << feature_type << " '" << feature
<< "' can not be specified in multiple config groups" << std::endl;
throw ParseException(ss.str());
}
restricted_features[feature] = std::make_pair(key_prefix + key, value);
restricted_features.Insert(
category, std::make_pair(key_prefix + key, value));
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/command_line_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include <unordered_map>
#include <vector>

#include "restricted_features.h"
#include "triton/common/logging.h"
#include "triton/core/tritonserver.h"
#ifdef TRITON_ENABLE_GRPC
Expand Down Expand Up @@ -189,7 +190,7 @@ struct TritonServerParameters {
std::string http_forward_header_pattern_;
// The number of threads to initialize for the HTTP front-end.
int http_thread_cnt_{8};
RestrictedFeatureMap http_restricted_apis_{};
RestrictedFeatures http_restricted_apis_{};
#endif // TRITON_ENABLE_HTTP

#ifdef TRITON_ENABLE_GRPC
Expand Down Expand Up @@ -285,7 +286,7 @@ class TritonParser {
void ParseRestrictedFeatureOption(
const std::string& arg, const std::string& option_name,
const std::string& header_prefix, const std::string& feature_name,
RestrictedFeatureMap& restricted_features);
RestrictedFeatures& restricted_features);
#ifdef TRITON_ENABLE_TRACING
TRITONSERVER_InferenceTraceLevel ParseTraceLevelOption(std::string arg);
InferenceTraceMode ParseTraceModeOption(std::string arg);
Expand Down
5 changes: 1 addition & 4 deletions src/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once

#include <algorithm>
#include <iostream>
#include <map>
#include <sstream>
Expand Down Expand Up @@ -171,8 +172,4 @@ bool Contains(const std::vector<std::string>& vec, const std::string& str);
/// \return The joint string.
std::string Join(const std::vector<std::string>& vec, const std::string& delim);

using RestrictedFeature = std::pair<std::string, std::string>;

using RestrictedFeatureMap = std::map<std::string, RestrictedFeature>;

}} // namespace triton::server
Loading

0 comments on commit a64ce5d

Please sign in to comment.