From af88d62ce38a208aae28083a622b072bf2d977f3 Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Mon, 13 Mar 2023 10:12:26 +0800 Subject: [PATCH] [Fix] Avoid scope switching when using mmdet inference interface (#2039) --- demo/topdown_demo_with_mmdet.py | 4 +-- .../webcam/nodes/model_nodes/detector_node.py | 4 +-- mmpose/utils/__init__.py | 4 ++- mmpose/utils/config_utils.py | 26 +++++++++++++++++++ 4 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 mmpose/utils/config_utils.py diff --git a/demo/topdown_demo_with_mmdet.py b/demo/topdown_demo_with_mmdet.py index 7ed3107b3a..f1a4e42b4d 100644 --- a/demo/topdown_demo_with_mmdet.py +++ b/demo/topdown_demo_with_mmdet.py @@ -8,13 +8,13 @@ import mmcv import mmengine import numpy as np -from mmengine.registry import init_default_scope from mmpose.apis import inference_topdown from mmpose.apis import init_model as init_pose_estimator from mmpose.evaluation.functional import nms from mmpose.registry import VISUALIZERS from mmpose.structures import merge_data_samples, split_instances +from mmpose.utils import adapt_mmdet_pipeline try: from mmdet.apis import inference_detector, init_detector @@ -28,7 +28,6 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer, """Visualize predicted keypoints (and heatmaps) of one image.""" # predict bbox - init_default_scope(detector.cfg.get('default_scope', 'mmdet')) det_result = inference_detector(detector, img_path) pred_instance = det_result.pred_instances.cpu().numpy() bboxes = np.concatenate( @@ -147,6 +146,7 @@ def main(): # build detector detector = init_detector( args.det_config, args.det_checkpoint, device=args.device) + detector.cfg = adapt_mmdet_pipeline(detector.cfg) # build pose estimator pose_estimator = init_pose_estimator( diff --git a/mmpose/apis/webcam/nodes/model_nodes/detector_node.py b/mmpose/apis/webcam/nodes/model_nodes/detector_node.py index fa925a2d25..350831fe62 100644 --- a/mmpose/apis/webcam/nodes/model_nodes/detector_node.py +++ b/mmpose/apis/webcam/nodes/model_nodes/detector_node.py @@ -2,8 +2,8 @@ from typing import Dict, List, Optional, Union import numpy as np -from mmengine.registry import init_default_scope +from mmpose.utils import adapt_mmdet_pipeline from ...utils import get_config_path from ..node import Node from ..registry import NODES @@ -92,6 +92,7 @@ def __init__(self, # Init model self.model = init_detector( self.model_config, self.model_checkpoint, device=self.device) + self.model.cfg = adapt_mmdet_pipeline(self.model.cfg) # Register buffers self.register_input_buffer(input_buffer, 'input', trigger=True) @@ -109,7 +110,6 @@ def process(self, input_msgs): img = input_msg.get_image() - init_default_scope(self.model.cfg.get('default_scope', 'mmdet')) preds = inference_detector(self.model, img) objects = self._post_process(preds) input_msg.update_objects(objects) diff --git a/mmpose/utils/__init__.py b/mmpose/utils/__init__.py index 044bb286c4..c48ca01cea 100644 --- a/mmpose/utils/__init__.py +++ b/mmpose/utils/__init__.py @@ -1,11 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. from .camera import SimpleCamera, SimpleCameraTorch from .collect_env import collect_env +from .config_utils import adapt_mmdet_pipeline from .logger import get_root_logger from .setup_env import register_all_modules, setup_multi_processes from .timer import StopWatch __all__ = [ 'get_root_logger', 'collect_env', 'StopWatch', 'setup_multi_processes', - 'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch' + 'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch', + 'adapt_mmdet_pipeline' ] diff --git a/mmpose/utils/config_utils.py b/mmpose/utils/config_utils.py new file mode 100644 index 0000000000..2f54d2ef24 --- /dev/null +++ b/mmpose/utils/config_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmpose.utils.typing import ConfigDict + + +def adapt_mmdet_pipeline(cfg: ConfigDict) -> ConfigDict: + """Converts pipeline types in MMDetection's test dataloader to use the + 'mmdet' namespace. + + Args: + cfg (ConfigDict): Configuration dictionary for MMDetection. + + Returns: + ConfigDict: Configuration dictionary with updated pipeline types. + """ + # use lazy import to avoid hard dependence on mmdet + from mmdet.datasets import transforms + + if 'test_dataloader' not in cfg: + return cfg + + pipeline = cfg.test_dataloader.dataset.pipeline + for trans in pipeline: + if trans['type'] in dir(transforms): + trans['type'] = 'mmdet.' + trans['type'] + + return cfg