Skip to content

Commit

Permalink
[Fix] Avoid scope switching when using mmdet inference interface (ope…
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored Mar 13, 2023
1 parent a5fdb41 commit 6b1ba8e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
4 changes: 2 additions & 2 deletions demo/topdown_demo_with_mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,13 @@
import mmcv
import mmengine
import numpy as np
from mmengine.registry import init_default_scope

from mmpose.apis import inference_topdown
from mmpose.apis import init_model as init_pose_estimator
from mmpose.evaluation.functional import nms
from mmpose.registry import VISUALIZERS
from mmpose.structures import merge_data_samples, split_instances
from mmpose.utils import adapt_mmdet_pipeline

try:
from mmdet.apis import inference_detector, init_detector
Expand All @@ -28,7 +28,6 @@ def process_one_image(args, img_path, detector, pose_estimator, visualizer,
"""Visualize predicted keypoints (and heatmaps) of one image."""

# predict bbox
init_default_scope(detector.cfg.get('default_scope', 'mmdet'))
det_result = inference_detector(detector, img_path)
pred_instance = det_result.pred_instances.cpu().numpy()
bboxes = np.concatenate(
Expand Down Expand Up @@ -147,6 +146,7 @@ def main():
# build detector
detector = init_detector(
args.det_config, args.det_checkpoint, device=args.device)
detector.cfg = adapt_mmdet_pipeline(detector.cfg)

# build pose estimator
pose_estimator = init_pose_estimator(
Expand Down
4 changes: 2 additions & 2 deletions mmpose/apis/webcam/nodes/model_nodes/detector_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Dict, List, Optional, Union

import numpy as np
from mmengine.registry import init_default_scope

from mmpose.utils import adapt_mmdet_pipeline
from ...utils import get_config_path
from ..node import Node
from ..registry import NODES
Expand Down Expand Up @@ -92,6 +92,7 @@ def __init__(self,
# Init model
self.model = init_detector(
self.model_config, self.model_checkpoint, device=self.device)
self.model.cfg = adapt_mmdet_pipeline(self.model.cfg)

# Register buffers
self.register_input_buffer(input_buffer, 'input', trigger=True)
Expand All @@ -109,7 +110,6 @@ def process(self, input_msgs):

img = input_msg.get_image()

init_default_scope(self.model.cfg.get('default_scope', 'mmdet'))
preds = inference_detector(self.model, img)
objects = self._post_process(preds)
input_msg.update_objects(objects)
Expand Down
4 changes: 3 additions & 1 deletion mmpose/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .camera import SimpleCamera, SimpleCameraTorch
from .collect_env import collect_env
from .config_utils import adapt_mmdet_pipeline
from .logger import get_root_logger
from .setup_env import register_all_modules, setup_multi_processes
from .timer import StopWatch

__all__ = [
'get_root_logger', 'collect_env', 'StopWatch', 'setup_multi_processes',
'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch'
'register_all_modules', 'SimpleCamera', 'SimpleCameraTorch',
'adapt_mmdet_pipeline'
]
26 changes: 26 additions & 0 deletions mmpose/utils/config_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmpose.utils.typing import ConfigDict


def adapt_mmdet_pipeline(cfg: ConfigDict) -> ConfigDict:
"""Converts pipeline types in MMDetection's test dataloader to use the
'mmdet' namespace.
Args:
cfg (ConfigDict): Configuration dictionary for MMDetection.
Returns:
ConfigDict: Configuration dictionary with updated pipeline types.
"""
# use lazy import to avoid hard dependence on mmdet
from mmdet.datasets import transforms

if 'test_dataloader' not in cfg:
return cfg

pipeline = cfg.test_dataloader.dataset.pipeline
for trans in pipeline:
if trans['type'] in dir(transforms):
trans['type'] = 'mmdet.' + trans['type']

return cfg

0 comments on commit 6b1ba8e

Please sign in to comment.