From 30154f0a79a9ff230da4fdae90cc39c43ddca70e Mon Sep 17 00:00:00 2001 From: ly015 Date: Sun, 12 Mar 2023 04:21:49 +0800 Subject: [PATCH 1/2] update get_flops.py --- mmpose/apis/inferencers/pose2d_inferencer.py | 17 ++- tools/analysis_tools/get_flops.py | 106 +++++++------------ 2 files changed, 51 insertions(+), 72 deletions(-) diff --git a/mmpose/apis/inferencers/pose2d_inferencer.py b/mmpose/apis/inferencers/pose2d_inferencer.py index 30ebd7c711..2f67605019 100644 --- a/mmpose/apis/inferencers/pose2d_inferencer.py +++ b/mmpose/apis/inferencers/pose2d_inferencer.py @@ -4,7 +4,6 @@ import mmcv import numpy as np -from mmdet.apis.det_inferencer import DetInferencer from mmengine.config import Config, ConfigDict from mmengine.infer.infer import ModelType from mmengine.registry import init_default_scope @@ -17,6 +16,12 @@ from .base_mmpose_inferencer import BaseMMPoseInferencer from .utils import default_det_models +try: + from mmdet.apis.det_inferencer import DetInferencer + mmdet_available = True +except (ImportError, ModuleNotFoundError): + mmdet_available = False + InstanceList = List[InstanceData] InputType = Union[str, np.ndarray] InputsType = Union[InputType, Sequence[InputType]] @@ -94,8 +99,14 @@ def __init__(self, det_model, det_weights, det_cat_ids = det_info[ 'model'], det_info['weights'], det_info['cat_ids'] - self.detector = DetInferencer( - det_model, det_weights, device=device) + if mmdet_available: + self.detector = DetInferencer( + det_model, det_weights, device=device) + else: + raise RuntimeError( + 'MMDetection is required to build inferencers ' + 'for top-down pose estimation models.') + if isinstance(det_cat_ids, (tuple, list)): self.det_cat_ids = det_cat_ids else: diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index 9325037699..b62a320909 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -1,73 +1,34 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse -from functools import partial import torch -from mmengine.config import DictAction - -from mmpose.apis.inference import init_model try: - from mmcv.cnn import get_model_complexity_info + from fvcore.nn import (ActivationCountAnalysis, FlopCountAnalysis, + flop_count_str, flop_count_table, parameter_count) except ImportError: - raise ImportError('Please upgrade mmcv to >0.6.2') + print('You may need to install fvcore for flops computation, ' + 'and you can use `pip install fvcore` to set up the environment') +from fvcore.nn.print_model_statistics import _format_size +from mmengine import Config + +from mmpose.models import build_pose_estimator +from mmpose.utils import register_all_modules def parse_args(): - parser = argparse.ArgumentParser(description='Train a recognizer') - parser.add_argument('config', help='train config file path') - parser.add_argument( - '--device', - default='cuda:0', - help='Device used for model initialization') - parser.add_argument( - '--cfg-options', - nargs='+', - action=DictAction, - default={}, - help='override some settings in the used config, the key-value pair ' - 'in xxx=yyy format will be merged into config file. For example, ' - "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") + parser = argparse.ArgumentParser(description='Get model flops and params') + parser.add_argument('config', help='config file path') parser.add_argument( '--shape', type=int, nargs='+', default=[256, 192], help='input image size') - parser.add_argument( - '--input-constructor', - '-c', - type=str, - choices=['none', 'batch'], - default='none', - help='If specified, it takes a callable method that generates ' - 'input. Otherwise, it will generate a random tensor with ' - 'input shape to calculate FLOPs.') - parser.add_argument( - '--batch-size', '-b', type=int, default=1, help='input batch size') - parser.add_argument( - '--not-print-per-layer-stat', - '-n', - action='store_true', - help='Whether to print complexity information' - 'for each layer in a model') args = parser.parse_args() return args -def batch_constructor(flops_model, batch_size, input_shape): - """Generate a batch of tensors to the model.""" - batch = {} - - inputs = torch.ones(()).new_empty( - (batch_size, *input_shape), - dtype=next(flops_model.parameters()).dtype, - device=next(flops_model.parameters()).device) - - batch['inputs'] = inputs - return batch - - def main(): args = parse_args() @@ -79,37 +40,44 @@ def main(): else: raise ValueError('invalid input shape') - model = init_model( - args.config, - checkpoint=None, - device=args.device, - cfg_options=args.cfg_options) - - if args.input_constructor == 'batch': - input_constructor = partial(batch_constructor, model, args.batch_size) - else: - input_constructor = None + cfg = Config.fromfile(args.config) + model = build_pose_estimator(cfg.model) + model.eval() - if hasattr(model, '_forward'): - model.forward = model._forward + if hasattr(model, 'extract_feat'): + model.forward = model.extract_feat else: raise NotImplementedError( 'FLOPs counter is currently not currently supported with {}'. format(model.__class__.__name__)) - flops, params = get_model_complexity_info( - model, - input_shape, - input_constructor=input_constructor, - print_per_layer_stat=(not args.not_print_per_layer_stat)) + inputs = (torch.randn((1, *input_shape)), ) + flops_ = FlopCountAnalysis(model, inputs) + activations_ = ActivationCountAnalysis(model, inputs) + + flops = _format_size(flops_.total()) + activations = _format_size(activations_.total()) + params = _format_size(parameter_count(model)['']) + + flop_table = flop_count_table( + flops=flops_, + activations=activations_, + show_param_shapes=True, + ) + flop_str = flop_count_str(flops=flops_, activations=activations_) + + print('\n' + flop_str) + print('\n' + flop_table) + split_line = '=' * 30 - input_shape = (args.batch_size, ) + input_shape print(f'{split_line}\nInput shape: {input_shape}\n' - f'Flops: {flops}\nParams: {params}\n{split_line}') + f'Flops: {flops}\nParams: {params}\n' + f'Activation: {activations}\n{split_line}') print('!!!Please be cautious if you use the results in papers. ' 'You may need to check if all ops are supported and verify that the ' 'flops computation is correct.') if __name__ == '__main__': + register_all_modules() main() From 3b0e12d08572ec8a18fb6fce81e874c87d250f27 Mon Sep 17 00:00:00 2001 From: ly015 Date: Mon, 13 Mar 2023 11:37:31 +0800 Subject: [PATCH 2/2] use mmengine --- mmpose/models/pose_estimators/base.py | 2 +- tools/analysis_tools/get_flops.py | 88 ++++++++++----------- tools/analysis_tools/get_flops1.py | 106 ++++++++++++++++++++++++++ 3 files changed, 152 insertions(+), 44 deletions(-) create mode 100644 tools/analysis_tools/get_flops1.py diff --git a/mmpose/models/pose_estimators/base.py b/mmpose/models/pose_estimators/base.py index b97232b344..057f0cf9e6 100644 --- a/mmpose/models/pose_estimators/base.py +++ b/mmpose/models/pose_estimators/base.py @@ -90,7 +90,7 @@ def _load_metainfo(metainfo: dict = None) -> dict: def forward(self, inputs: torch.Tensor, - data_samples: OptSampleList, + data_samples: OptSampleList = None, mode: str = 'tensor') -> ForwardResults: """The unified entry for a forward process in both training and test. diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index b62a320909..3c2c5c47fc 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -1,83 +1,85 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse -import torch +from mmengine.config import Config, DictAction + +from mmpose.registry import MODELS +from mmpose.utils import register_all_modules try: - from fvcore.nn import (ActivationCountAnalysis, FlopCountAnalysis, - flop_count_str, flop_count_table, parameter_count) + from mmengine.analysis import get_model_complexity_info except ImportError: - print('You may need to install fvcore for flops computation, ' - 'and you can use `pip install fvcore` to set up the environment') -from fvcore.nn.print_model_statistics import _format_size -from mmengine import Config - -from mmpose.models import build_pose_estimator -from mmpose.utils import register_all_modules + raise ImportError('Please upgrade mmcv to >0.6.2') def parse_args(): - parser = argparse.ArgumentParser(description='Get model flops and params') - parser.add_argument('config', help='config file path') + parser = argparse.ArgumentParser(description='Train a detector') + parser.add_argument('config', help='train config file path') parser.add_argument( '--shape', type=int, nargs='+', - default=[256, 192], + default=[1280, 800], help='input image size') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') args = parser.parse_args() return args def main(): - + register_all_modules() args = parse_args() if len(args.shape) == 1: - input_shape = (3, args.shape[0], args.shape[0]) + h = w = args.shape[0] elif len(args.shape) == 2: - input_shape = (3, ) + tuple(args.shape) + h, w = args.shape else: raise ValueError('invalid input shape') + input_shape = (3, h, w) cfg = Config.fromfile(args.config) - model = build_pose_estimator(cfg.model) - model.eval() - - if hasattr(model, 'extract_feat'): - model.forward = model.extract_feat - else: - raise NotImplementedError( - 'FLOPs counter is currently not currently supported with {}'. - format(model.__class__.__name__)) - - inputs = (torch.randn((1, *input_shape)), ) - flops_ = FlopCountAnalysis(model, inputs) - activations_ = ActivationCountAnalysis(model, inputs) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) - flops = _format_size(flops_.total()) - activations = _format_size(activations_.total()) - params = _format_size(parameter_count(model)['']) + model = MODELS.build(cfg.model) + model.eval() - flop_table = flop_count_table( - flops=flops_, - activations=activations_, - show_param_shapes=True, - ) - flop_str = flop_count_str(flops=flops_, activations=activations_) + analysis_results = get_model_complexity_info( + model, input_shape, show_table=True, show_arch=False) - print('\n' + flop_str) - print('\n' + flop_table) + # ayalysis_results = { + # 'flops': flops, + # 'flops_str': flops_str, + # 'activations': activations, + # 'activations_str': activations_str, + # 'params': params, + # 'params_str': params_str, + # 'out_table': complexity_table, + # 'out_arch': complexity_arch + # } split_line = '=' * 30 print(f'{split_line}\nInput shape: {input_shape}\n' - f'Flops: {flops}\nParams: {params}\n' - f'Activation: {activations}\n{split_line}') + f'Flops: {analysis_results["flops"]}\n' + f'Params: {analysis_results["params"]}\n{split_line}') + + print(analysis_results['activations']) + # print(analysis_results['complexity_table']) + # print(complexity_str) print('!!!Please be cautious if you use the results in papers. ' 'You may need to check if all ops are supported and verify that the ' 'flops computation is correct.') if __name__ == '__main__': - register_all_modules() main() diff --git a/tools/analysis_tools/get_flops1.py b/tools/analysis_tools/get_flops1.py new file mode 100644 index 0000000000..ab603e1e3b --- /dev/null +++ b/tools/analysis_tools/get_flops1.py @@ -0,0 +1,106 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import argparse + +import torch +from mmengine.config import DictAction + +from mmpose.apis.inference import init_model + +try: + # from mmcv.cnn import get_model_complexity_info + from mmengine.analysis import get_model_complexity_info +except ImportError: + raise ImportError('Please upgrade mmcv to >0.6.2') + + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a recognizer') + parser.add_argument('config', help='train config file path') + parser.add_argument( + '--device', default='cpu', help='Device used for model initialization') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + default={}, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. For example, ' + "'--cfg-options model.backbone.depth=18 model.backbone.with_cp=True'") + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[256, 192], + help='input image size') + parser.add_argument( + '--input-constructor', + '-c', + type=str, + choices=['none', 'batch'], + default='none', + help='If specified, it takes a callable method that generates ' + 'input. Otherwise, it will generate a random tensor with ' + 'input shape to calculate FLOPs.') + parser.add_argument( + '--batch-size', '-b', type=int, default=1, help='input batch size') + parser.add_argument( + '--not-print-per-layer-stat', + '-n', + action='store_true', + help='Whether to print complexity information' + 'for each layer in a model') + args = parser.parse_args() + return args + + +def batch_constructor(flops_model, batch_size, input_shape): + """Generate a batch of tensors to the model.""" + batch = {} + + inputs = torch.ones(()).new_empty( + (batch_size, *input_shape), + dtype=next(flops_model.parameters()).dtype, + device=next(flops_model.parameters()).device) + + batch['inputs'] = inputs + return batch + + +def main(): + + args = parse_args() + + if len(args.shape) == 1: + input_shape = (3, args.shape[0], args.shape[0]) + elif len(args.shape) == 2: + input_shape = (3, ) + tuple(args.shape) + else: + raise ValueError('invalid input shape') + + model = init_model( + args.config, + checkpoint=None, + device=args.device, + cfg_options=args.cfg_options) + + if hasattr(model, '_forward'): + model.forward = model._forward + else: + raise NotImplementedError( + 'FLOPs counter is currently not currently supported with {}'. + format(model.__class__.__name__)) + + analysis_results = get_model_complexity_info(model, input_shape) + flops = analysis_results['flops_str'] + params = analysis_results['params_str'] + split_line = '=' * 30 + input_shape = (args.batch_size, ) + input_shape + print(f'{split_line}\nInput shape: {input_shape}\n' + f'Flops: {flops}\nParams: {params}\n{split_line}') + print('!!!Please be cautious if you use the results in papers. ' + 'You may need to check if all ops are supported and verify that the ' + 'flops computation is correct.') + + +if __name__ == '__main__': + main()