diff --git a/.gitignore b/.gitignore
index eb428883fe..bdecd24034 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,24 @@
 bazel-*
+build
+mediapipe.egg-info
+mediapipe/__pycache__/
 mediapipe/MediaPipe.xcodeproj
 mediapipe/MediaPipe.tulsiproj/*.tulsiconf-user
+mediapipe/models/ovms/face_detection_short_range/
+mediapipe/models/ovms/face_landmark/
+mediapipe/models/ovms/hand_landmark_full/
+mediapipe/models/ovms/hand_recrop/
+mediapipe/models/ovms/iris_landmark/
+mediapipe/models/ovms/palm_detection_full/
+mediapipe/models/ovms/pose_detection/
+mediapipe/models/ovms/pose_landmark_full/
+mediapipe/models/ovms/ssdlite_object_detection/
+mediapipe/models/ssdlite_object_detection_labelmap.txt
 mediapipe/provisioning_profile.mobileprovision
+mediapipe/python/__pycache__/
 node_modules/
 .configure.bazelrc
 .user.bazelrc
 .vscode/
 .vs/
+*.mp4
diff --git a/Dockerfile.openvino b/Dockerfile.openvino
index 282a5246ac..f52f2557cb 100644
--- a/Dockerfile.openvino
+++ b/Dockerfile.openvino
@@ -85,12 +85,6 @@
 RUN apt-get update && apt-get install --no-install-recommends -y \
         libopencv-imgproc-dev \
         libopencv-video-dev \
         build-essential \
-        libboost-filesystem1.71.0 \
-        libboost-thread1.71.0 \
-        libboost-program-options1.71.0 \
-        libboost-chrono1.71.0 \
-        libboost-date-time1.71.0 \
-        libboost-atomic1.71.0 \
         libjson-c4 \
         unzip
diff --git a/WORKSPACE b/WORKSPACE
index 611a12e1b4..64dce0ccbe 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -638,10 +638,12 @@ http_archive(
     build_file = "@//third_party:halide.BUILD",
 )
 
+load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository")
+
 git_repository(
     name = "ovms",
     remote = "https://github.com/openvinotoolkit/model_server",
-    commit = "77c30dc3f153b3ee78336a3a75c09af4e23c14a4", # MP update to 10.3 in OVMS
+    commit = "ad1381fde838f2ac2d23117df78c186a96134fcc", # Fix azure patch (#2107)
 )
 
 # DEV ovms - adjust local repository path for build
diff --git a/mediapipe/__init__.py b/mediapipe/__init__.py
index bf29fed398..d812d5d2e6 100644
--- a/mediapipe/__init__.py
+++ b/mediapipe/__init__.py
@@ -1,13 +1,13 @@
-# Copyright 2019 - 2022 The MediaPipe Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+
+from mediapipe.python import *
+import mediapipe.python.solutions as solutions
+import mediapipe.tasks.python as tasks
+
+
+del framework
+del gpu
+del modules
+del python
+del mediapipe
+del util
+__version__ = '1.0'
diff --git a/mediapipe/examples/python/README.md b/mediapipe/examples/python/README.md
new file mode 100644
index 0000000000..e3ac725f4b
--- /dev/null
+++ b/mediapipe/examples/python/README.md
@@ -0,0 +1,33 @@
+# OVMS Python examples
+- Build the docker container with dependencies
+```bash
+git clone https://github.com/openvinotoolkit/mediapipe.git
+cd mediapipe
+make docker_build
+```
+
+- Start the container
+```bash
+docker run -it mediapipe_ovms:latest bash
+```
+
+- Prepare the models for OVMS
+```bash
+python setup_ovms.py --get_models
+```
+
+- Build and install the mediapipe Python package
+Make sure you are in the /mediapipe directory.
+The command below takes around an hour, depending on your internet speed and CPU.
+```bash
+pip install .
+```
+
+- Run the example OVMS Python script
+```bash
+python build/lib.linux-x86_64-cpython-38/mediapipe/examples/python/ovms_object_detection.py
+```
+
+- This script runs object detection on the input video, as described in the C++ example:
+[OVMS Object Detection](../desktop/object_detection/README.md)
+[Original demo documentation](https://google.github.io/mediapipe/solutions/object_detection)
\ No newline at end of file
diff --git a/mediapipe/examples/python/__init__.py b/mediapipe/examples/python/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/mediapipe/examples/python/ovms_object_detection.py b/mediapipe/examples/python/ovms_object_detection.py
new file mode 100644
index 0000000000..f4c1ad27f5
--- /dev/null
+++ b/mediapipe/examples/python/ovms_object_detection.py
@@ -0,0 +1,21 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import mediapipe as mp
+ovms_object_detection = mp.solutions.ovms_object_detection
+with ovms_object_detection.OvmsObjectDetection(side_inputs={
+    'input_video_path': '/mediapipe/mediapipe/examples/desktop/object_detection/test_video.mp4',
+    'output_video_path': '/mediapipe/tested_video.mp4'}) as ovms_object_detection:
+  results = ovms_object_detection.process()
diff --git a/mediapipe/modules/ovms_modules/BUILD b/mediapipe/modules/ovms_modules/BUILD
new file mode 100644
index 0000000000..9883ac7a21
--- /dev/null
+++ b/mediapipe/modules/ovms_modules/BUILD
@@ -0,0 +1,39 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+load(
+    "//mediapipe/framework/tool:mediapipe_graph.bzl",
+    "mediapipe_simple_subgraph",
+)
+load(
+    "//mediapipe/framework/tool:mediapipe_files.bzl",
+    "mediapipe_files",
+)
+load("//mediapipe/framework/port:build_config.bzl", "mediapipe_proto_library")
+load("//mediapipe/framework:mediapipe_cc_test.bzl", "mediapipe_cc_test")
+
+licenses(["notice"])
+
+package(default_visibility = ["//visibility:public"])
+
+mediapipe_simple_subgraph(
+    name = "object_detection_ovms",
+    graph = "object_detection_ovms.pbtxt",
+    register_as = "ObjectDetectionOvms",
+    deps = [
+        "//mediapipe/graphs/object_detection:desktop_ovms_calculators",
+        "@ovms//src:ovms_lib",
+    ],
+)
diff --git a/mediapipe/modules/ovms_modules/object_detection_ovms.pbtxt b/mediapipe/modules/ovms_modules/object_detection_ovms.pbtxt
new file mode 100644
index 0000000000..a58a1f8ba9
--- /dev/null
+++ b/mediapipe/modules/ovms_modules/object_detection_ovms.pbtxt
@@ -0,0 +1,210 @@
+# MediaPipe graph that performs object detection on desktop, on CPU, with
+# OpenVINO Model Server.
+# Used in the example in
+# mediapipe/examples/desktop/object_detection:object_detection_openvino.
+
+# max_queue_size limits the number of packets enqueued on any input stream
+# by throttling inputs to the graph. This makes the graph process only one
+# frame at a time.
+max_queue_size: 1
+
+# Decodes an input video file into images and a video header.
+node {
+  calculator: "OpenCvVideoDecoderCalculator"
+  input_side_packet: "INPUT_FILE_PATH:input_video_path"
+  output_stream: "VIDEO:input_video"
+  output_stream: "VIDEO_PRESTREAM:input_video_header"
+}
+
+# Transforms the input image on CPU to a 320x320 image. To scale the image, by
+# default it uses the STRETCH scale mode that maps the entire input image to the
+# entire transformed image. As a result, image aspect ratio may be changed and
+# objects in the image may be deformed (stretched or squeezed), but the object
+# detection model used in this graph is agnostic to that deformation.
+node: {
+  calculator: "ImageTransformationCalculator"
+  input_stream: "IMAGE:input_video"
+  output_stream: "IMAGE:transformed_input_video"
+  node_options: {
+    [type.googleapis.com/mediapipe.ImageTransformationCalculatorOptions] {
+      output_width: 320
+      output_height: 320
+    }
+  }
+}
+
+# Converts the transformed input image on CPU into an image tensor as an
+# OpenVINOTensor. The zero_center option is set to true to normalize the
+# pixel values to [-1.f, 1.f] as opposed to [0.f, 1.f].
+node {
+  calculator: "OpenVINOConverterCalculator"
+  input_stream: "IMAGE:transformed_input_video"
+  output_stream: "TENSORS:image_tensor"
+  node_options: {
+    [type.googleapis.com/mediapipe.OpenVINOConverterCalculatorOptions] {
+      enable_normalization: true
+      zero_center: true
+    }
+  }
+}
+
+# Runs the object-detection model through OpenVINO Model Server. The model
+# takes an image tensor and outputs a vector of tensors representing, for
+# instance, detection boxes/keypoints and scores.
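+#
+# The two nodes below split server access: the session calculator creates the
+# OVMS servable session once and shares it as a SESSION side packet, and the
+# inference calculator then uses that session to run inference on every frame.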
+node {
+  calculator: "OpenVINOModelServerSessionCalculator"
+  output_side_packet: "SESSION:session"
+  node_options: {
+    [type.googleapis.com/mediapipe.OpenVINOModelServerSessionCalculatorOptions]: {
+      servable_name: "ssdlite_object_detection"  # servable name inside OVMS
+      servable_version: "1"
+      server_config: "mediapipe/calculators/ovms/config.json"
+    }
+  }
+}
+node {
+  calculator: "OpenVINOInferenceCalculator"
+  input_side_packet: "SESSION:session"
+  input_stream: "OVTENSORS:image_tensor"
+  output_stream: "OVTENSORS2:detection_tensors"
+  node_options: {
+    [type.googleapis.com/mediapipe.OpenVINOInferenceCalculatorOptions]: {
+      input_order_list: ["normalized_input_image_tensor"]
+      output_order_list: ["raw_outputs/box_encodings", "raw_outputs/class_predictions"]
+      tag_to_input_tensor_names {
+        key: "OVTENSORS"
+        value: "normalized_input_image_tensor"
+      }
+      tag_to_output_tensor_names {
+        key: "OVTENSORS1"
+        value: "raw_outputs/box_encodings"
+      }
+      tag_to_output_tensor_names {
+        key: "OVTENSORS2"
+        value: "raw_outputs/class_predictions"
+      }
+    }
+  }
+}
+
+# Generates a single side packet containing a vector of SSD anchors based on
+# the specification in the options.
+node {
+  calculator: "SsdAnchorsCalculator"
+  output_side_packet: "anchors"
+  node_options: {
+    [type.googleapis.com/mediapipe.SsdAnchorsCalculatorOptions] {
+      num_layers: 6
+      min_scale: 0.2
+      max_scale: 0.95
+      input_size_height: 320
+      input_size_width: 320
+      anchor_offset_x: 0.5
+      anchor_offset_y: 0.5
+      strides: 16
+      strides: 32
+      strides: 64
+      strides: 128
+      strides: 256
+      strides: 512
+      aspect_ratios: 1.0
+      aspect_ratios: 2.0
+      aspect_ratios: 0.5
+      aspect_ratios: 3.0
+      aspect_ratios: 0.3333
+      reduce_boxes_in_lowest_layer: true
+    }
+  }
+}
+
+# Decodes the detection tensors produced by the inference calculator, based on
+# the SSD anchors and the specification in the options, into a vector of
+# detections. Each detection describes a detected object.
+node {
+  calculator: "OpenVINOTensorsToDetectionsCalculator"
+  input_stream: "TENSORS:detection_tensors"
+  input_side_packet: "ANCHORS:anchors"
+  output_stream: "DETECTIONS:detections"
+  node_options: {
+    [type.googleapis.com/mediapipe.OpenVINOTensorsToDetectionsCalculatorOptions] {
+      num_classes: 91
+      num_boxes: 2034
+      num_coords: 4
+      ignore_classes: 0
+      apply_exponential_on_box_size: true
+
+      x_scale: 10.0
+      y_scale: 10.0
+      h_scale: 5.0
+      w_scale: 5.0
+    }
+  }
+}
+
+# Performs non-max suppression to remove excessive detections.
+node {
+  calculator: "NonMaxSuppressionCalculator"
+  input_stream: "detections"
+  output_stream: "filtered_detections"
+  node_options: {
+    [type.googleapis.com/mediapipe.NonMaxSuppressionCalculatorOptions] {
+      min_suppression_threshold: 0.4
+      min_score_threshold: 0.6
+      max_num_detections: 5
+      overlap_type: INTERSECTION_OVER_UNION
+    }
+  }
+}
+
+# Maps detection label IDs to the corresponding label text. The label map is
+# provided in the label_map_path option.
+node {
+  calculator: "DetectionLabelIdToTextCalculator"
+  input_stream: "filtered_detections"
+  output_stream: "output_detections"
+  node_options: {
+    [type.googleapis.com/mediapipe.DetectionLabelIdToTextCalculatorOptions] {
+      label_map_path: "/mediapipe/mediapipe/models/ssdlite_object_detection_labelmap.txt"
+    }
+  }
+}
+
+# Converts the detections to drawing primitives for annotation overlay.
+node {
+  calculator: "DetectionsToRenderDataCalculator"
+  input_stream: "DETECTIONS:output_detections"
+  output_stream: "RENDER_DATA:render_data"
+  node_options: {
+    [type.googleapis.com/mediapipe.DetectionsToRenderDataCalculatorOptions] {
+      thickness: 4.0
+      color { r: 255 g: 0 b: 0 }
+    }
+  }
+}
+
+# Draws annotations and overlays them on top of the input images.
+node {
+  calculator: "AnnotationOverlayCalculator"
+  input_stream: "IMAGE:input_video"
+  input_stream: "render_data"
+  output_stream: "IMAGE:output_video"
+}
+
+# Encodes the annotated images into a video file, adopting properties specified
+# in the input video header, e.g., video framerate.
+node {
+  calculator: "OpenCvVideoEncoderCalculator"
+  input_stream: "VIDEO:output_video"
+  input_stream: "VIDEO_PRESTREAM:input_video_header"
+  input_side_packet: "OUTPUT_FILE_PATH:output_video_path"
+  node_options: {
+    [type.googleapis.com/mediapipe.OpenCvVideoEncoderCalculatorOptions]: {
+      codec: "avc1"
+      video_format: "mp4"
+    }
+  }
+}
diff --git a/mediapipe/python/BUILD b/mediapipe/python/BUILD
index 085fbc96bc..91b78dd3c5 100644
--- a/mediapipe/python/BUILD
+++ b/mediapipe/python/BUILD
@@ -58,6 +58,8 @@ pybind_extension(
         "//mediapipe/framework/formats:rect_registration",
        "//mediapipe/modules/objectron/calculators:annotation_registration",
         "//mediapipe/tasks/cc/vision/face_geometry/proto:face_geometry_registration",
+        # OVMS lib
+        "@ovms//src:ovms_lib",
     ],
 )
 
diff --git a/mediapipe/python/solutions/__init__.py b/mediapipe/python/solutions/__init__.py
index 3490699520..ec50720901 100644
--- a/mediapipe/python/solutions/__init__.py
+++ b/mediapipe/python/solutions/__init__.py
@@ -23,5 +23,6 @@
 import mediapipe.python.solutions.hands_connections
 import mediapipe.python.solutions.holistic
 import mediapipe.python.solutions.objectron
+import mediapipe.python.solutions.ovms_object_detection
 import mediapipe.python.solutions.pose
 import mediapipe.python.solutions.selfie_segmentation
diff --git a/mediapipe/python/solutions/ovms_object_detection.py b/mediapipe/python/solutions/ovms_object_detection.py
new file mode 100644
index 0000000000..fb658a5455
--- /dev/null
+++ b/mediapipe/python/solutions/ovms_object_detection.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Ovms Object Detection."""
+
+from mediapipe.calculators.ovms import openvinoinferencecalculator_pb2
+from mediapipe.calculators.ovms import openvinomodelserversessioncalculator_pb2
+from mediapipe.python.solution_base import SolutionBase
+
+_FULL_GRAPH_FILE_PATH = 'mediapipe/modules/ovms_modules/object_detection_ovms.binarypb'
+
+class OvmsObjectDetection(SolutionBase):
+  """Ovms Object Detection.
+
+  Ovms Object Detection processes an input video and returns an output
+  video with the detected objects.
+ """ + """ + Oryginal params in desktop example + --calculator_graph_config_file mediapipe/graphs/object_detection/object_detection_desktop_ovms1_graph.pbtxt + --input_side_packets "input_video_path=/mediapipe/mediapipe/examples/desktop/object_detection/test_video.mp4,output_video_path=/mediapipe/tested_video.mp4 + """ + def __init__(self, + side_inputs= + {'input_video_path':'/mediapipe/mediapipe/examples/desktop/object_detection/test_video.mp4', + 'output_video_path':'/mediapipe/tested_video.mp4'}): + """Initializes a Ovms Object Detection object. + """ + super().__init__( + binary_graph_path=_FULL_GRAPH_FILE_PATH, + side_inputs=side_inputs) + + def process(self): + self._graph.wait_until_done() + return None diff --git a/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto b/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto index ec11df2b47..232f206f74 100644 --- a/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto +++ b/mediapipe/tasks/cc/components/processors/proto/detection_postprocessing_graph_options.proto @@ -27,23 +27,23 @@ import "mediapipe/tasks/cc/components/calculators/score_calibration_calculator.p message DetectionPostprocessingGraphOptions { // Optional SsdAnchorsCalculatorOptions for models without // non-maximum-suppression in tflite model graph. - optional mediapipe.SsdAnchorsCalculatorOptions ssd_anchors_options = 1; + mediapipe.SsdAnchorsCalculatorOptions ssd_anchors_options = 1; // Optional TensorsToDetectionsCalculatorOptions for models without // non-maximum-suppression in tflite model graph. - optional mediapipe.TensorsToDetectionsCalculatorOptions + mediapipe.TensorsToDetectionsCalculatorOptions tensors_to_detections_options = 2; // Optional NonMaxSuppressionCalculatorOptions for models without // non-maximum-suppression in tflite model graph. - optional mediapipe.NonMaxSuppressionCalculatorOptions + mediapipe.NonMaxSuppressionCalculatorOptions non_max_suppression_options = 3; // Optional score calibration options for models with non-maximum-suppression // in tflite model graph. - optional ScoreCalibrationCalculatorOptions score_calibration_options = 4; + ScoreCalibrationCalculatorOptions score_calibration_options = 4; // Optional detection label id to text calculator options. - optional mediapipe.DetectionLabelIdToTextCalculatorOptions + mediapipe.DetectionLabelIdToTextCalculatorOptions detection_label_ids_to_text_options = 5; } diff --git a/setup.py b/setup.py index 73d0ec3c30..fc72267951 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ ] GPU_OPTIONS = GPU_OPTIONS_DISBALED if MP_DISABLE_GPU else GPU_OPTIONS_ENBALED +OVMS_OPTIONS = ['--define=MEDIAPIPE_DISABLE=1 --define=PYTHON_DISABLE=1 --cxxopt=-DPYTHON_DISABLE=1 --cxxopt=-DMEDIAPIPE_DISABLE=1'] def _normalize_path(path): return path.replace('\\', '/') if IS_WINDOWS else path @@ -131,6 +132,8 @@ def _add_mp_init_files(): # Save the original mediapipe/__init__.py file. 
   shutil.copyfile(MP_DIR_INIT_PY, _get_backup_file(MP_DIR_INIT_PY))
   mp_dir_init_file = open(MP_DIR_INIT_PY, 'a')
+  # Clear the file first (dropping the license header) so repeated builds don't keep appending.
+  mp_dir_init_file.truncate(0)
   mp_dir_init_file.writelines([
       '\n', 'from mediapipe.python import *\n',
       'import mediapipe.python.solutions as solutions \n',
@@ -272,7 +275,8 @@ def run(self):
         'hand_landmark/hand_landmark_tracking_cpu',
         'holistic_landmark/holistic_landmark_cpu', 'objectron/objectron_cpu',
         'pose_landmark/pose_landmark_cpu',
-        'selfie_segmentation/selfie_segmentation_cpu'
+        'selfie_segmentation/selfie_segmentation_cpu',
+        'ovms_modules/object_detection_ovms'
     ]
     for elem in binary_graphs:
       binary_graph = os.path.join('mediapipe/modules/', elem)
@@ -300,7 +304,7 @@ def _generate_binary_graph(self, binary_graph_target):
         '--copt=-DNDEBUG',
         '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
         binary_graph_target,
-    ] + GPU_OPTIONS
+    ] + GPU_OPTIONS + OVMS_OPTIONS
 
     if not self.link_opencv and not IS_WINDOWS:
       bazel_command.append('--define=OPENCV=source')
@@ -326,7 +330,7 @@ def run(self):
         '--compilation_mode=opt',
         '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
         '//mediapipe/tasks/metadata:' + target,
-    ] + GPU_OPTIONS
+    ] + GPU_OPTIONS + OVMS_OPTIONS
 
     _invoke_shell_command(bazel_command)
     _copy_to_build_lib_dir(
@@ -413,7 +417,7 @@ def _build_binary(self, ext, extra_args=None):
         '--copt=-DNDEBUG',
         '--action_env=PYTHON_BIN_PATH=' + _normalize_path(sys.executable),
         str(ext.bazel_target + '.so'),
-    ] + GPU_OPTIONS
+    ] + GPU_OPTIONS + OVMS_OPTIONS
 
     if extra_args:
       bazel_command += extra_args
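
For completeness, a minimal usage sketch of the Python surface this patch adds (not part of the diff itself). It assumes the package was built and installed per mediapipe/examples/python/README.md and that the models were fetched with `python setup_ovms.py --get_models`; the paths are the defaults wired into the new solution, and `detector` is just an illustrative name.

```python
import mediapipe as mp

# The new solution module registered in mediapipe/python/solutions/__init__.py.
ovms_object_detection = mp.solutions.ovms_object_detection

# side_inputs map onto the graph's input_side_packets; the values below are
# the defaults baked into OvmsObjectDetection.__init__.
with ovms_object_detection.OvmsObjectDetection(side_inputs={
    'input_video_path': '/mediapipe/mediapipe/examples/desktop/object_detection/test_video.mp4',
    'output_video_path': '/mediapipe/tested_video.mp4',
}) as detector:
  # process() blocks until the graph finishes; the annotated video is then
  # available at output_video_path (the call itself returns None).
  detector.process()
```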