diff --git a/.gitignore b/.gitignore index 47a09753a2..f1b69cb25e 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ __pycache__ tmp *.log +test_results.txt diff --git a/qa/L0_backend_python/python_based_backends/python_based_backends_test.py b/qa/L0_backend_python/python_based_backends/python_based_backends_test.py new file mode 100644 index 0000000000..13fe204267 --- /dev/null +++ b/qa/L0_backend_python/python_based_backends/python_based_backends_test.py @@ -0,0 +1,144 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import sys +import unittest +from random import randint + +import numpy as np +import tritonclient.grpc as grpcclient +from tritonclient.utils import * + +sys.path.append("../../common") +from test_util import TestResultCollector + + +class PythonBasedBackendsTest(TestResultCollector): + def setUp(self): + self.triton_client = grpcclient.InferenceServerClient(url="localhost:8001") + self.add_sub_model_1 = "add" + self.add_sub_model_2 = "sub" + self.python_model = "add_sub" + self.pytorch_model = "add_sub_pytorch" + + self.triton_client.load_model( + self.add_sub_model_1, + config='{"backend":"add_sub","version_policy":{"latest":{"num_versions":2}}}', + ) + self.triton_client.load_model(self.add_sub_model_2) + self.triton_client.load_model(self.python_model) + self.triton_client.load_model(self.pytorch_model) + + def test_add_sub_models(self): + self.assertTrue( + self.triton_client.is_model_ready(self.add_sub_model_1, model_version="2") + ) + self._test_add_sub_model( + model_name=self.add_sub_model_1, model_version="2", single_output=True + ) + + self.assertTrue( + self.triton_client.is_model_ready(self.add_sub_model_1, model_version="1") + ) + self._test_add_sub_model( + model_name=self.add_sub_model_1, model_version="1", single_output=True + ) + + self.assertTrue(self.triton_client.is_model_ready(self.add_sub_model_2)) + self._test_add_sub_model(model_name=self.add_sub_model_2, single_output=True) + + def test_python_model(self): + self.assertTrue( + self.triton_client.is_model_ready(self.python_model, model_version="2") + ) + self._test_add_sub_model( + model_name=self.python_model, shape=[16], model_version="2" + ) + + def test_pytorch_model(self): + self.assertTrue( + self.triton_client.is_model_ready(self.pytorch_model, model_version="1") + ) + self._test_add_sub_model(model_name=self.pytorch_model) + + def _test_add_sub_model( + self, model_name, model_version="1", shape=[4], single_output=False + ): + input0_data = np.random.rand(*shape).astype(np.float32) + input1_data = np.random.rand(*shape).astype(np.float32) + + inputs = [ + grpcclient.InferInput( + "INPUT0", input0_data.shape, np_to_triton_dtype(input0_data.dtype) + ), + grpcclient.InferInput( + "INPUT1", input1_data.shape, np_to_triton_dtype(input1_data.dtype) + ), + ] + + inputs[0].set_data_from_numpy(input0_data) + inputs[1].set_data_from_numpy(input1_data) + + if single_output: + outputs = [grpcclient.InferRequestedOutput("OUTPUT")] + + else: + outputs = [ + grpcclient.InferRequestedOutput("OUTPUT0"), + grpcclient.InferRequestedOutput("OUTPUT1"), + ] + + response = self.triton_client.infer( + model_name=model_name, + inputs=inputs, + model_version=model_version, + request_id=str(randint(10, 99)), + outputs=outputs, + ) + + if single_output: + if model_name == "add": + self.assertTrue( + np.allclose(input0_data + input1_data, response.as_numpy("OUTPUT")) + ) + else: + self.assertTrue( + np.allclose(input0_data - input1_data, response.as_numpy("OUTPUT")) + ) + else: + self.assertTrue( + np.allclose(input0_data + input1_data, response.as_numpy("OUTPUT0")) + ) + self.assertTrue( + np.allclose(input0_data - input1_data, response.as_numpy("OUTPUT1")) + ) + + def tearDown(self): + self.triton_client.close() + + +if __name__ == "__main__": + unittest.main() diff --git a/qa/L0_backend_python/python_based_backends/test.sh b/qa/L0_backend_python/python_based_backends/test.sh new file mode 100755 index 0000000000..0f332eb3e0 --- /dev/null +++ b/qa/L0_backend_python/python_based_backends/test.sh @@ -0,0 +1,113 @@ +#!/bin/bash +# 
Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +source ../../common/util.sh + +TRITON_DIR=${TRITON_DIR:="/opt/tritonserver"} +SERVER=${TRITON_DIR}/bin/tritonserver +BACKEND_DIR=${TRITON_DIR}/backends +QA_MODELS_PATH="../../python_models" +MODEL_REPOSITORY="$(pwd)/models" +SERVER_ARGS="--model-repository=${MODEL_REPOSITORY} --backend-directory=${BACKEND_DIR} --model-control-mode=explicit --log-verbose=1" +SERVER_LOG="./python_based_backends_server.log" +CLIENT_LOG="./python_based_backends_client.log" +TEST_RESULT_FILE="./test_results.txt" +CLIENT_PY="./python_based_backends_test.py" +GEN_PYTORCH_MODEL_PY="../../common/gen_qa_pytorch_model.py" +EXPECTED_NUM_TESTS=3 +RET=0 + +rm -rf ${MODEL_REPOSITORY} +pip3 install torch + +# Setup add_sub backend and models +mkdir -p ${BACKEND_DIR}/add_sub +cp ${QA_MODELS_PATH}/python_based_backends/add_sub_backend/model.py ${BACKEND_DIR}/add_sub/model.py + +mkdir -p ${MODEL_REPOSITORY}/add/1/ +echo '{ "operation": "add" }' > ${MODEL_REPOSITORY}/add/1/model.json +echo "backend: \"add_sub\"" > ${MODEL_REPOSITORY}/add/config.pbtxt +cp -r ${MODEL_REPOSITORY}/add/1/ ${MODEL_REPOSITORY}/add/2/ + +mkdir -p ${MODEL_REPOSITORY}/sub/1/ +echo '{ "operation": "sub" }' > ${MODEL_REPOSITORY}/sub/1/model.json +echo "backend: \"add_sub\"" > ${MODEL_REPOSITORY}/sub/config.pbtxt + +# Setup python backend model +mkdir -p ${MODEL_REPOSITORY}/add_sub/1 +cp ${QA_MODELS_PATH}/add_sub/model.py ${MODEL_REPOSITORY}/add_sub/1/ +cp ${QA_MODELS_PATH}/add_sub/config.pbtxt ${MODEL_REPOSITORY}/add_sub/ +cp -r ${MODEL_REPOSITORY}/add_sub/1/ ${MODEL_REPOSITORY}/add_sub/2/ + +# Setup pytorch backend model +cp ${GEN_PYTORCH_MODEL_PY} ./gen_qa_pytorch_model.py +GEN_PYTORCH_MODEL_PY=./gen_qa_pytorch_model.py + +set +e +python3 ${GEN_PYTORCH_MODEL_PY} -m ${MODEL_REPOSITORY} + +if [ $? -ne 0 ]; then + echo -e "\n***\n*** Running ${GEN_PYTORCH_MODEL_PY} FAILED. 
\n***" + exit 1 +fi +set -e + +run_server +if [ "$SERVER_PID" == "0" ]; then + cat $SERVER_LOG + echo -e "\n***\n*** Failed to start $SERVER\n***" + exit 1 +fi + +set +e +python3 $CLIENT_PY -v >$CLIENT_LOG 2>&1 + +if [ $? -ne 0 ]; then + echo -e "\n***\n*** Running $CLIENT_PY FAILED. \n***" + RET=1 +else + check_test_results $TEST_RESULT_FILE $EXPECTED_NUM_TESTS + if [ $? -ne 0 ]; then + echo -e "\n***\n*** Test Result Verification FAILED.\n***" + RET=1 + fi +fi +set -e + +kill $SERVER_PID +wait $SERVER_PID +rm -rf ${MODEL_REPOSITORY} ${GEN_PYTORCH_MODEL_PY} + +if [ $RET -eq 1 ]; then + cat $CLIENT_LOG + cat $SERVER_LOG + echo -e "\n***\n*** Python-based Backends test FAILED. \n***" +else + echo -e "\n***\n*** Python-based Backends test PASSED. \n***" +fi + +exit $RET diff --git a/qa/L0_backend_python/test.sh b/qa/L0_backend_python/test.sh index af8000b2aa..23c2ce75b4 100755 --- a/qa/L0_backend_python/test.sh +++ b/qa/L0_backend_python/test.sh @@ -395,7 +395,7 @@ fi # Disable variants test for Jetson since already built without GPU Tensor support # Disable decoupled test because it uses GPU tensors if [ "$TEST_JETSON" == "0" ]; then - SUBTESTS="ensemble io bls decoupled variants" + SUBTESTS="ensemble io bls decoupled variants python_based_backends" for TEST in ${SUBTESTS}; do # Run each subtest in a separate virtual environment to avoid conflicts # between dependencies. diff --git a/qa/common/gen_qa_pytorch_model.py b/qa/common/gen_qa_pytorch_model.py new file mode 100644 index 0000000000..2daee9cffc --- /dev/null +++ b/qa/common/gen_qa_pytorch_model.py @@ -0,0 +1,124 @@ +# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +import argparse +import os + +import torch +from torch import nn + + +class AddSubNet(nn.Module): + def __init__(self): + super(AddSubNet, self).__init__() + + def forward(self, input0, input1): + return (input0 + input1), (input0 - input1) + + +def generate_model(model_dir): + model = AddSubNet() + + traced_model = torch.jit.trace( + model, + (torch.rand(1, 4, dtype=torch.float), torch.rand(1, 4, dtype=torch.float)), + ) + + os.makedirs(model_dir, exist_ok=True) + model_path = os.path.join(model_dir, "model.pt") + + traced_model.save(model_path) + + +def generate_config(config_path): + with open(f"{config_path}/config.pbtxt", "w") as f: + f.write( + """ +backend: "pytorch" +input [ + { + name: "INPUT0" + data_type: TYPE_FP32 + dims: [ 4 ] + } +] +input [ + { + name: "INPUT1" + data_type: TYPE_FP32 + dims: [ 4 ] + } +] +output [ + { + name: "OUTPUT0" + data_type: TYPE_FP32 + dims: [ 4 ] + } +] +output [ + { + name: "OUTPUT1" + data_type: TYPE_FP32 + dims: [ 4 ] + } +] +""" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", + "--model-directory", + type=str, + required=True, + help="The path to the model repository.", + ) + parser.add_argument( + "--model-name", + type=str, + required=False, + default="add_sub_pytorch", + help="Model name", + ) + parser.add_argument( + "--version", + type=str, + required=False, + default="1", + help="Model version", + ) + + args = parser.parse_args() + + model_directory = os.path.join(args.model_directory, args.model_name) + os.makedirs(model_directory, exist_ok=True) + + generate_model(model_dir=os.path.join(model_directory, args.version)) + generate_config(model_directory) diff --git a/qa/python_models/python_based_backends/add_sub_backend/model.py b/qa/python_models/python_based_backends/add_sub_backend/model.py new file mode 100644 index 0000000000..7c9736b2d5 --- /dev/null +++ b/qa/python_models/python_based_backends/add_sub_backend/model.py @@ -0,0 +1,162 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +import json +import os + +import triton_python_backend_utils as pb_utils + +_ADD_SUB_ARGS_FILENAME = "model.json" + + +class TritonPythonModel: + @staticmethod + def auto_complete_config(auto_complete_model_config): + """This function is called only once when loading the model assuming + the server was not started with `--disable-auto-complete-config`. + + Parameters + ---------- + auto_complete_model_config : pb_utils.ModelConfig + An object containing the existing model configuration. + + Returns + ------- + pb_utils.ModelConfig + An object containing the auto-completed model configuration + """ + inputs = [ + {"name": "INPUT0", "data_type": "TYPE_FP32", "dims": [4]}, + {"name": "INPUT1", "data_type": "TYPE_FP32", "dims": [4]}, + ] + outputs = [{"name": "OUTPUT", "data_type": "TYPE_FP32", "dims": [4]}] + + config = auto_complete_model_config.as_dict() + input_names = [] + output_names = [] + + for input in config["input"]: + input_names.append(input["name"]) + + for output in config["output"]: + output_names.append(output["name"]) + + for input in inputs: + if input["name"] not in input_names: + auto_complete_model_config.add_input(input) + + for output in outputs: + if output["name"] not in output_names: + auto_complete_model_config.add_output(output) + + return auto_complete_model_config + + def initialize(self, args): + """This function allows the model to initialize any state associated with this model. + + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + + self.model_config = model_config = json.loads(args["model_config"]) + + # Get OUTPUT configuration + output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") + + engine_args_filepath = os.path.join( + pb_utils.get_model_dir(), _ADD_SUB_ARGS_FILENAME + ) + assert os.path.isfile( + engine_args_filepath + ), f"'{_ADD_SUB_ARGS_FILENAME}' containing add sub model args must be provided in '{pb_utils.get_model_dir()}'" + + with open(engine_args_filepath) as file: + self.add_sub_config = json.load(file) + + assert ( + "operation" in self.add_sub_config + ), f"Missing required key 'operation' in {_ADD_SUB_ARGS_FILENAME}" + + extra_keys = set(self.add_sub_config.keys()) - {"operation"} + assert ( + not extra_keys + ), f"Unsupported keys are provided in {_ADD_SUB_ARGS_FILENAME}: {', '.join(extra_keys)}" + + assert self.add_sub_config["operation"] in [ + "add", + "sub", + ], f"'operation' value must be 'add' or 'sub' in {_ADD_SUB_ARGS_FILENAME}" + + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + def execute(self, requests): + """This function is called when an inference request is made + for this model. + + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + + Returns + ------- + list + A list of pb_utils.InferenceResponse. 
The length of this list must + be the same as `requests` + """ + + responses = [] + + for request in requests: + in_0 = pb_utils.get_input_tensor_by_name(request, "INPUT0") + in_1 = pb_utils.get_input_tensor_by_name(request, "INPUT1") + + if self.add_sub_config["operation"] == "add": + out = in_0.as_numpy() + in_1.as_numpy() + else: + out = in_0.as_numpy() - in_1.as_numpy() + + # Create output tensors. + out_tensor = pb_utils.Tensor("OUTPUT", out.astype(self.output_dtype)) + + # Create InferenceResponse. + inference_response = pb_utils.InferenceResponse(output_tensors=[out_tensor]) + responses.append(inference_response) + + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded.""" + print("Cleaning up...")
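Taken together, the patch exercises the Python-based backend flow: a backend is just a model.py placed under <backend-directory>/<backend-name>/, and any model whose config.pbtxt names that backend is served by it, with per-model arguments read from model.json. A minimal standalone sketch of that flow, using the same paths and flags that test.sh above sets up (the /opt/tritonserver prefix and the "add" model name are the script's assumptions, not requirements):

    # Register the add_sub Python-based backend (one model.py per backend name).
    TRITON_DIR=/opt/tritonserver
    mkdir -p ${TRITON_DIR}/backends/add_sub
    cp qa/python_models/python_based_backends/add_sub_backend/model.py ${TRITON_DIR}/backends/add_sub/

    # Create a model that uses the backend: config.pbtxt routes it to add_sub,
    # while model.json carries the backend-specific "operation" argument.
    mkdir -p models/add/1
    echo '{ "operation": "add" }' > models/add/1/model.json
    echo 'backend: "add_sub"' > models/add/config.pbtxt

    # Serve it (test.sh additionally passes --model-control-mode=explicit and loads
    # the models through the client; without that flag they load at startup).
    ${TRITON_DIR}/bin/tritonserver --model-repository=$(pwd)/models \
        --backend-directory=${TRITON_DIR}/backends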