Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Бенчмаркинг PaddlePaddle] Реализация пайплайна для PaddlePaddle #508

Merged
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
8f8ca1e
init
IgorKonovalovAleks Mar 3, 2024
d39ce82
style fix, input fix
IgorKonovalovAleks Mar 10, 2024
5853169
style fix
IgorKonovalovAleks Mar 10, 2024
56d0529
gpu device support
IgorKonovalovAleks Mar 19, 2024
3ccef36
style 1
IgorKonovalovAleks Mar 20, 2024
a216280
parser arguments + output handling fix
IgorKonovalovAleks Mar 25, 2024
fcedaec
parameters parser
IgorKonovalovAleks Apr 20, 2024
571a573
init
IgorKonovalovAleks Mar 3, 2024
e4f28c8
style fix, input fix
IgorKonovalovAleks Mar 10, 2024
8e3bbaf
style fix
IgorKonovalovAleks Mar 10, 2024
8eae024
gpu device support
IgorKonovalovAleks Mar 19, 2024
d30fbdc
style 1
IgorKonovalovAleks Mar 20, 2024
472f920
parser arguments + output handling fix
IgorKonovalovAleks Mar 25, 2024
62177bc
parameters parser
IgorKonovalovAleks Apr 20, 2024
770b1df
config + model download
IgorKonovalovAleks Apr 21, 2024
12a50a3
Merge remote-tracking branch 'origin/paddlepaddle_inference' into pad…
IgorKonovalovAleks Apr 21, 2024
a59847c
config fix
IgorKonovalovAleks Apr 21, 2024
1358078
config fix 1
IgorKonovalovAleks Apr 21, 2024
1658a07
framework wrapper
IgorKonovalovAleks Apr 21, 2024
c1de22c
command line fix
IgorKonovalovAleks Apr 21, 2024
2eb6594
command line fix 1
IgorKonovalovAleks Apr 21, 2024
9f18ee2
command line fix 2
IgorKonovalovAleks Apr 21, 2024
40b2ef4
command line fix 3
IgorKonovalovAleks Apr 23, 2024
b1ec5ab
command line fix 4
IgorKonovalovAleks Apr 28, 2024
9529fb4
requirements update
IgorKonovalovAleks Apr 28, 2024
6f47a12
args
IgorKonovalovAleks Apr 28, 2024
a8603be
command line fix 5
IgorKonovalovAleks Apr 29, 2024
54dbec5
model path and input shape requirement
IgorKonovalovAleks May 8, 2024
14bb7e6
Merge branch 'master' into paddlepaddle_inference
IgorKonovalovAleks May 8, 2024
9368060
model path fix
IgorKonovalovAleks May 8, 2024
86b695e
Merge remote-tracking branch 'origin/paddlepaddle_inference' into pad…
IgorKonovalovAleks May 8, 2024
4dc1cff
Dockerfile added
IgorKonovalovAleks May 11, 2024
beeaed1
validation results
IgorKonovalovAleks Jun 9, 2024
8433b04
Merge branch 'master' into paddlepaddle_inference
IgorKonovalovAleks Sep 12, 2024
34d14b3
paddlepaddle url updated
IgorKonovalovAleks Sep 19, 2024
e8c8b19
paddlepaddle readme updates
IgorKonovalovAleks Oct 10, 2024
b4dbc87
Merge remote-tracking branch 'origin/paddlepaddle_inference' into pad…
IgorKonovalovAleks Oct 10, 2024
c0406bb
docker fix
IgorKonovalovAleks Oct 24, 2024
66f1e71
requested fixes
IgorKonovalovAleks Oct 26, 2024
7628abf
line length fix
IgorKonovalovAleks Oct 26, 2024
e94eb90
requested changes 2
IgorKonovalovAleks Oct 28, 2024
9141c43
general try-catch
IgorKonovalovAleks Oct 31, 2024
c28dacb
try-except format
IgorKonovalovAleks Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 214 additions & 0 deletions src/inference/inference_paddlepaddle.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import paddle.inference as paddle_infer
import argparse
import json
import sys
from pathlib import Path

import postprocessing_data as pp
import preprocessing_data as prep
from inference_tools.loop_tools import get_exec_time
from inference_tools.loop_tools import loop_inference
from io_adapter import IOAdapter
from io_model_wrapper import PaddlePaddleIOModelWrapper
from reporter.report_writer import ReportWriter
from transformer import PaddlePaddleTransformer

sys.path.append(str(Path(__file__).resolve().parents[1].joinpath('utils')))
from logger_conf import configure_logger # noqa: E402
log = configure_logger()


def cli_argument_parser():
    """Parse command-line arguments of the PaddlePaddle inference benchmark.

    Returns:
        argparse.Namespace with model/params paths, input configuration,
        execution parameters and report settings.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model',
                        help='Path to a .pdmodel file.',
                        required=True,
                        type=str,
                        dest='model_path')
    parser.add_argument('-p', '--params',
                        help='Path to .pdiparams file.',
                        required=True,
                        type=str,
                        dest='params_path')
    parser.add_argument('-i', '--input',
                        help='Path to data',
                        required=False,
                        default=None,
                        type=str,
                        nargs='+',
                        dest='input')
    parser.add_argument('-b', '--batch_size',
                        help='Size of the processed pack',
                        default=1,
                        type=int,
                        dest='batch_size')
    # 'feedforward' is the default, so it must be a legal choice as well;
    # otherwise an explicit `-t feedforward` would be rejected while the
    # implicit default silently passed argparse validation.
    parser.add_argument('-t', '--task',
                        help='Output processing method. Default: without postprocess',
                        choices=['feedforward', 'classification'],
                        default='feedforward',
                        type=str,
                        dest='task')
    parser.add_argument('-ni', '--number_iter',
                        help='Number of inference iterations',
                        default=1,
                        type=int,
                        dest='number_iter')
    # `type=bool` is broken for CLI flags: bool('False') is True, so any
    # provided value enabled raw output. A store_true flag is the correct
    # way to expose a boolean switch.
    parser.add_argument('--raw_output',
                        help='Raw output without logs',
                        action='store_true',
                        dest='raw_output')
    parser.add_argument('--mean',
                        help='Parameter mean',
                        default=None,
                        type=str,
                        dest='mean')
    parser.add_argument('--input_scale',
                        help='Parameter input scale',
                        default=None,
                        type=str,
                        dest='input_scale')
    parser.add_argument('--layout',
                        help='Parameter input layout',
                        default=None,
                        type=str,
                        dest='layout')
    parser.add_argument('--input_shapes',
                        help='Input tensor shapes',
                        default=None,
                        type=str,
                        dest='input_shapes',
                        required=True)
    parser.add_argument('--input_names',
                        help='Names of the input tensors',
                        default=None,
                        type=prep.names_arg,
                        dest='input_names')
    parser.add_argument('--output_names',
                        help='Name of the output tensor',
                        default=None,
                        type=str,
                        nargs='+',
                        dest='output_names')
    parser.add_argument('-d', '--device',
                        help='Specify the target device to infer on (CPU by default)',
                        default='CPU',
                        type=str,
                        dest='device')
    parser.add_argument('-nt', '--number_top',
                        help='Number of top results to print',
                        default=5,
                        type=int,
                        dest='number_top')
    parser.add_argument('--report_path',
                        type=Path,
                        default=Path(__file__).parent / 'paddle_inference_report.json',
                        dest='report_path')
    parser.add_argument('--time', required=False, default=0, type=int,
                        dest='time',
                        help='Optional. Time in seconds to execute topology')
    parser.add_argument('--memory_pool_init_size_mb', required=False, default=1000, type=int,
                        dest='memory_pool_init_size_mb', help='Initial size of the the allocated gpu memory, in MB')

    args = parser.parse_args()

    return args


def inference_paddlepaddle(predictor, number_iter, get_slice, test_duration):
    """Run inference and collect per-iteration timings.

    For a single iteration the output tensors are copied to host memory and
    returned; for multiple iterations only timings are gathered (result is
    None), since outputs are overwritten on every pass anyway.

    Returns:
        tuple (result dict or None, list of per-iteration times in seconds).
    """
    input_names = predictor.get_input_names()
    if number_iter > 1:
        # Benchmark mode: delegate the iteration loop to loop_inference.
        looped = loop_inference(number_iter, test_duration)(inference_iteration)
        infer_times, _ = looped(get_slice, input_names, predictor)
        return None, infer_times
    elapsed = inference_iteration(get_slice, input_names, predictor)
    outputs = {
        name: predictor.get_output_handle(name).copy_to_cpu()
        for name in predictor.get_output_names()
    }
    return outputs, [elapsed]


def inference_iteration(get_slice, input_info, predictor):
    """Feed the next input slice into the predictor and run one timed pass.

    Note: ``input_info`` is kept for interface symmetry with the other
    framework launchers; the tensor names come from the slice itself.
    """
    for tensor_name, tensor_data in get_slice().items():
        predictor.get_input_handle(tensor_name).copy_from_cpu(tensor_data)
    _, elapsed = infer_slice(predictor)
    return elapsed


@get_exec_time()
def infer_slice(predictor):
    # Execute one forward pass; wall-clock time is measured and returned
    # by the get_exec_time decorator as (result, exec_time).
    predictor.run()


def prepare_output(result, output_names, task):
    """Validate raw inference outputs before printing.

    Args:
        result: dict mapping output tensor names to host arrays.
        output_names: expected output tensor names (must match result size).
        task: postprocessing mode; only 'classification' is supported.

    Raises:
        ValueError: on missing/mismatched names or an unsupported task.
    """
    names_missing = output_names is None
    if names_missing or len(result) != len(output_names):
        raise ValueError('The number of output tensors does not match the number of corresponding output names')
    if task != 'classification':
        raise ValueError(f'Unsupported task {task} to print inference results')
    return result


def main():
    """Run the PaddlePaddle inference benchmark end to end.

    Parses CLI arguments, builds the predictor, prepares input data,
    measures inference time and writes a JSON performance report.

    Returns:
        None on success, 1 on failure, so that ``sys.exit(main() or 0)``
        propagates a non-zero status to the shell.
    """
    args = cli_argument_parser()
    # Reviewer-requested guard: any unexpected exception is logged with a
    # traceback instead of escaping as an unhandled crash.
    try:
        report_writer = ReportWriter()
        report_writer.update_framework_info(name='PaddlePaddle', version=paddle_infer.get_version())
        report_writer.update_configuration_setup(batch_size=args.batch_size,
                                                 iterations_num=args.number_iter,
                                                 target_device=args.device)

        config = paddle_infer.Config(args.model_path, args.params_path)
        config.enable_memory_optim()
        if args.device == 'GPU':
            config.enable_use_gpu(args.memory_pool_init_size_mb, 0)
        predictor = paddle_infer.create_predictor(config)
        # Input shapes must be applied before wrapping the predictor so the
        # IO adapter sees the final static shapes.
        args.input_shapes = prep.parse_input_arg(args.input_shapes, args.input_names)
        for name in predictor.get_input_names():
            predictor.get_input_handle(name).reshape(args.input_shapes[name])
        model_wrapper = PaddlePaddleIOModelWrapper(predictor)

        args.mean = prep.parse_input_arg(args.mean, args.input_names)
        args.input_scale = prep.parse_input_arg(args.input_scale, args.input_names)
        args.layout = prep.parse_layout_arg(args.layout, args.input_names)

        data_transformer = PaddlePaddleTransformer(prep.create_dict_for_transformer(args, 'NHWC'))
        io = IOAdapter.get_io_adapter(args, model_wrapper, data_transformer)

        if args.input and args.input != ['None']:
            log.info(f'Preparing input data: {args.input}')
            io.prepare_input(predictor, args.input)
        else:
            # No input provided: fill input tensors with generated data.
            io.fill_unset_inputs(predictor, log)

        log.info(f'Starting inference ({args.number_iter} iterations)')
        result, inference_time = inference_paddlepaddle(predictor, args.number_iter, io.get_slice_input, args.time)

        inference_result = pp.calculate_performance_metrics_sync_mode(args.batch_size, inference_time)

        report_writer.update_execution_results(**inference_result)
        log.info(f'Write report to {args.report_path}')
        report_writer.write_report(args.report_path)

        if not args.raw_output:
            if args.number_iter == 1:
                # Output tensors are only collected in single-iteration mode.
                try:
                    log.info('Converting output tensor to print results')
                    result = prepare_output(result, args.output_names, args.task)

                    log.info('Inference results')
                    io.process_output(result, log)
                except Exception as ex:
                    log.warning('Error when printing inference results. {0}'.format(str(ex)))

            log.info(f'Performance results:\n{json.dumps(inference_result, indent=4)}')
    except Exception:
        log.error('Benchmark failed', exc_info=True)
        return 1

if __name__ == '__main__':
    # Propagate the benchmark status code to the shell (None -> 0).
    sys.exit(main() or 0)
15 changes: 15 additions & 0 deletions src/inference/io_model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,21 @@ def get_input_layer_dtype(self, model, layer_name):
return float32


class PaddlePaddleIOModelWrapper(IOModelWrapper):
    """IO wrapper exposing input-layer metadata of a PaddlePaddle predictor."""

    def __init__(self, predictor):
        # Cache the input tensor names once; they do not change afterwards.
        self._input_names = predictor.get_input_names()

    def get_input_layer_names(self, model):
        """Return a fresh list of the cached input tensor names."""
        return [name for name in self._input_names]

    def get_input_layer_shape(self, predictor, layer_name):
        """Return the shape of the named input tensor as reported by Paddle."""
        handle = predictor.get_input_handle(layer_name)
        return handle.shape()

    def get_input_layer_dtype(self, predictor, layer_name):
        """Return the assumed input dtype (all inputs treated as float32)."""
        import numpy
        return numpy.float32


class TVMIOModelWrapper(IOModelWrapper):
def __init__(self, args):
self._input_names = [args['input_name']]
Expand Down
4 changes: 4 additions & 0 deletions src/inference/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ class PyTorchTransformer(TensorFlowLiteTransformer):
pass


class PaddlePaddleTransformer(TensorFlowLiteTransformer):
    # PaddlePaddle reuses the TensorFlow Lite preprocessing pipeline unchanged.
    pass


class ONNXRuntimeTransformer(TensorFlowLiteTransformer):
pass

Expand Down
Loading