Skip to content

Commit

Permalink
Enabling option to restrict access to HTTP APIs based on header value…
Browse files Browse the repository at this point in the history
… pairs (similar to gRPC)
  • Loading branch information
nnshah1 authored Nov 3, 2023
1 parent c1b334f commit 4b481a6
Show file tree
Hide file tree
Showing 15 changed files with 530 additions and 175 deletions.
19 changes: 18 additions & 1 deletion qa/L0_grpc/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,24 @@ SERVER_ARGS="--model-repository=${MODELDIR} \
--grpc-restricted-protocol=model-repository,health:k1=v1 \
--grpc-restricted-protocol=metadata,health:k2=v2"
run_server
EXPECTED_MSG="protocol 'health' can not be specified in multiple config group"
EXPECTED_MSG="protocol 'health' can not be specified in multiple config groups"
if [ "$SERVER_PID" != "0" ]; then
echo -e "\n***\n*** Expect fail to start $SERVER\n***"
kill $SERVER_PID
wait $SERVER_PID
RET=1
elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then
echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***"
cat $SERVER_LOG
RET=1
fi

# Unknown protocol, not allowed
SERVER_ARGS="--model-repository=${MODELDIR} \
--grpc-restricted-protocol=model-reposit,health:k1=v1 \
--grpc-restricted-protocol=metadata,health:k2=v2"
run_server
EXPECTED_MSG="unknown restricted protocol 'model-reposit'"
if [ "$SERVER_PID" != "0" ]; then
echo -e "\n***\n*** Expect fail to start $SERVER\n***"
kill $SERVER_PID
Expand Down
94 changes: 94 additions & 0 deletions qa/L0_http/http_restricted_api_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
#!/usr/bin/python
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import sys

sys.path.append("../common")

import unittest

import numpy as np
import tritonclient.http as tritonhttpclient
from tritonclient.utils import InferenceServerException


class RestrictedAPITest(unittest.TestCase):
def setUp(self):
self.model_name_ = "simple"
self.client_ = tritonhttpclient.InferenceServerClient("localhost:8000")

# Other unspecified APIs should not be restricted
def test_sanity(self):
self.client_.get_inference_statistics("simple")
self.client_.get_inference_statistics(
"simple", headers={"infer-key": "infer-value"}
)

# metadata, infer, model repository APIs are restricted.
# metadata and infer expects "infer-key : infer-value" header,
# model repository expected "admin-key : admin-value".
def test_model_repository(self):
with self.assertRaisesRegex(InferenceServerException, "This API is restricted"):
self.client_.unload_model(
self.model_name_, headers={"infer-key": "infer-value"}
)
# Request go through and get actual transaction error
with self.assertRaisesRegex(
InferenceServerException, "explicit model load / unload is not allowed"
):
self.client_.unload_model(
self.model_name_, headers={"admin-key": "admin-value"}
)

def test_metadata(self):
with self.assertRaisesRegex(InferenceServerException, "This API is restricted"):
self.client_.get_server_metadata()
self.client_.get_server_metadata({"infer-key": "infer-value"})

def test_infer(self):
# setup
inputs = [
tritonhttpclient.InferInput("INPUT0", [1, 16], "INT32"),
tritonhttpclient.InferInput("INPUT1", [1, 16], "INT32"),
]
inputs[0].set_data_from_numpy(np.ones(shape=(1, 16), dtype=np.int32))
inputs[1].set_data_from_numpy(np.ones(shape=(1, 16), dtype=np.int32))

# This test only care if the request goes through
with self.assertRaisesRegex(InferenceServerException, "This API is restricted"):
_ = self.client_.infer(
model_name=self.model_name_, inputs=inputs, headers={"test": "1"}
)
self.client_.infer(
model_name=self.model_name_,
inputs=inputs,
headers={"infer-key": "infer-value"},
)


if __name__ == "__main__":
unittest.main()
71 changes: 69 additions & 2 deletions qa/L0_http/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ RET=0

CLIENT_PLUGIN_TEST="./http_client_plugin_test.py"
BASIC_AUTH_TEST="./http_basic_auth_test.py"
RESTRICTED_API_TEST="./http_restricted_api_test.py"
NGINX_CONF="./nginx.conf"
# On windows the paths invoked by the script (running in WSL) must use
# /mnt/c when needed but the paths on the tritonserver command-line
Expand Down Expand Up @@ -588,13 +589,15 @@ kill $SERVER_PID
wait $SERVER_PID

# Run python unit test
rm -r ${MODELDIR}/*
MODELDIR=python_unit_test_models
mkdir -p $MODELDIR
rm -rf ${MODELDIR}/*
cp -r $DATADIR/qa_identity_model_repository/onnx_zero_1_float32 ${MODELDIR}/.
cp -r $DATADIR/qa_identity_model_repository/onnx_zero_1_object ${MODELDIR}/.
cp -r $DATADIR/qa_identity_model_repository/onnx_zero_1_float16 ${MODELDIR}/.
cp -r $DATADIR/qa_identity_model_repository/onnx_zero_3_float32 ${MODELDIR}/.
cp -r ${MODELDIR}/onnx_zero_1_object ${MODELDIR}/onnx_zero_1_object_1_element && \
(cd models/onnx_zero_1_object_1_element && \
(cd $MODELDIR/onnx_zero_1_object_1_element && \
sed -i "s/onnx_zero_1_object/onnx_zero_1_object_1_element/" config.pbtxt && \
sed -i "0,/-1/{s/-1/1/}" config.pbtxt)

Expand Down Expand Up @@ -667,6 +670,70 @@ set -e
kill $SERVER_PID
wait $SERVER_PID

### Test Restricted APIs ###
### Repeated API not allowed

MODELDIR="`pwd`/models"
SERVER_ARGS="--model-repository=${MODELDIR}
--http-restricted-api=model-repository,health:k1=v1 \
--http-restricted-api=metadata,health:k2=v2"
run_server
EXPECTED_MSG="api 'health' can not be specified in multiple config groups"
if [ "$SERVER_PID" != "0" ]; then
echo -e "\n***\n*** Expect fail to start $SERVER\n***"
kill $SERVER_PID
wait $SERVER_PID
RET=1
elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then
echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***"
cat $SERVER_LOG
RET=1
fi

### Test Unknown Restricted API###
### Unknown API not allowed

MODELDIR="`pwd`/models"
SERVER_ARGS="--model-repository=${MODELDIR}
--http-restricted-api=model-reposit,health:k1=v1 \
--http-restricted-api=metadata,health:k2=v2"
run_server
EXPECTED_MSG="unknown restricted api 'model-reposit'"
if [ "$SERVER_PID" != "0" ]; then
echo -e "\n***\n*** Expect fail to start $SERVER\n***"
kill $SERVER_PID
wait $SERVER_PID
RET=1
elif [ `grep -c "${EXPECTED_MSG}" ${SERVER_LOG}` != "1" ]; then
echo -e "\n***\n*** Failed. Expected ${EXPECTED_MSG} to be found in log\n***"
cat $SERVER_LOG
RET=1
fi

### Test Restricted APIs ###
### Restricted model-repository, metadata, and inference

SERVER_ARGS="--model-repository=${MODELDIR} \
--http-restricted-api=model-repository:admin-key=admin-value \
--http-restricted-api=inference,metadata:infer-key=infer-value"
run_server
if [ "$SERVER_PID" == "0" ]; then
echo -e "\n***\n*** Failed to start $SERVER\n***"
cat $SERVER_LOG
exit 1
fi
set +e

python $RESTRICTED_API_TEST RestrictedAPITest > $CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Python HTTP Restricted Protocol Test Failed\n***"
RET=1
fi
set -e
kill $SERVER_PID
wait $SERVER_PID

###

if [ $RET -eq 0 ]; then
Expand Down
Loading

0 comments on commit 4b481a6

Please sign in to comment.