Skip to content

Commit

Permalink
Fixed passing of bf16 and tf.string types to convert_model().
Browse files Browse the repository at this point in the history
  • Loading branch information
popovaan committed Apr 18, 2024
1 parent bf832a4 commit 9fe63a9
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 11 deletions.
21 changes: 21 additions & 0 deletions tests/layer_tests/ovc_python_api_tests/test_complex_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,27 @@ def test_ovc_convert_model_with_several_output(self, ie_device, precision, ir_ve
self._test(temp_dir, convert_model_params, cli_tool_params)


@pytest.mark.nightly
@pytest.mark.precommit
def test_non_numpy_types(self, ie_device, precision, ir_version, temp_dir, use_legacy_frontend):
import tensorflow as tf
def func(a, b):
return [a, b]
model = tf.function(func, input_signature=[tf.TensorSpec([2], tf.float32, "a"),
tf.TensorSpec([2], tf.float32, "b")])
parameter1 = ov.opset8.parameter(ov.Shape([2]), ov.Type.bf16)
parameter2 = ov.opset8.parameter(ov.Shape([2]), ov.Type.bf16)
bf16_ref = ov.Model([parameter1, parameter2], [parameter1, parameter2])

parameter1 = ov.opset8.parameter(ov.Shape([2]), ov.Type.string)
parameter2 = ov.opset8.parameter(ov.Shape([2]), ov.Type.string)
string_ref = ov.Model([parameter1, parameter2], [parameter1, parameter2])

self._test_by_ref_graph(temp_dir, {'input_model': model, 'input': [ov.Type.bf16, tf.bfloat16]}, bf16_ref, compare_tensor_names=False)
self._test_by_ref_graph(temp_dir, {'input_model': model, 'input': {'a': ov.Type.bf16, 'b': tf.bfloat16}}, bf16_ref, compare_tensor_names=False)
self._test_by_ref_graph(temp_dir, {'input_model': model, 'input': [ov.Type.string, tf.string]}, string_ref, compare_tensor_names=False)
self._test_by_ref_graph(temp_dir, {'input_model': model, 'input': {'a': ov.Type.string, 'b': tf.string}}, string_ref, compare_tensor_names=False)

class NegativeCases(unittest.TestCase):
test_directory = os.path.dirname(os.path.realpath(__file__))

Expand Down
14 changes: 6 additions & 8 deletions tools/ovc/openvino/tools/ovc/convert_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from pathlib import Path
from typing import Iterable, Callable

import numpy as np

try:
import openvino_telemetry as tm
from openvino_telemetry.backend import backend_ga4
Expand Down Expand Up @@ -343,10 +345,8 @@ def normalize_inputs(argv: argparse.Namespace):
shape_dict[inp.name] = None
if inp.type is not None:
# Convert type to numpy type for uniformity of stored values
if isinstance(inp.type, str):
data_type_dict[inp.name] = destination_type_to_np_data_type(inp.type)
elif isinstance(inp.type, Type):
data_type_dict[inp.name] = inp.type.to_dtype().type
if isinstance(inp.type, (np.dtype, str)):
data_type_dict[inp.name] = Type(inp.type)
else:
data_type_dict[inp.name] = inp.type
argv.placeholder_shapes = shape_dict if shape_dict else None
Expand All @@ -361,10 +361,8 @@ def normalize_inputs(argv: argparse.Namespace):
shape_list.append(PartialShape(inp.shape))
if inp.type is not None:
# Convert type to numpy type for uniformity of stored values
if isinstance(inp.type, str):
data_type_list.append(destination_type_to_np_data_type(inp.type))
elif isinstance(inp.type, Type):
data_type_list.append(inp.type.to_dtype().type)
if isinstance(inp.type, (np.dtype, str)):
data_type_list.append(Type(inp.type))
else:
data_type_list.append(inp.type)
argv.placeholder_shapes = shape_list if shape_list else None
Expand Down
2 changes: 1 addition & 1 deletion tools/ovc/openvino/tools/ovc/moc_frontend/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def create_target_input_shapes(new_input_places):
input_model.set_partial_shape(
user_shape['node'], user_shape['shape'])
if user_shape.get('data_type') is not None:
data_type = get_element_type(user_shape['data_type'])
data_type = user_shape['data_type']
log.debug('Set data type: {}'.format(data_type))
input_model.set_element_type(user_shape['node'], data_type)

Expand Down
19 changes: 17 additions & 2 deletions tools/ovc/openvino/tools/ovc/moc_frontend/type_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,22 @@ def to_ov_type(val):
if 'tensorflow' in sys.modules:
import tensorflow as tf # pylint: disable=import-error
if isinstance(val, tf.dtypes.DType):
return Type(val.as_numpy_dtype())
tf_to_ov_type = {
tf.float32: ov.Type.f32,
tf.float16: ov.Type.f16,
tf.float64: ov.Type.f64,
tf.bfloat16: ov.Type.bf16,
tf.uint8: ov.Type.u8,
tf.int8: ov.Type.i8,
tf.int16: ov.Type.i16,
tf.int32: ov.Type.i32,
tf.int64: ov.Type.i64,
tf.bool: ov.Type.boolean,
tf.string: ov.Type.string
}
if val not in tf_to_ov_type:
raise Exception("The provided data time is not supported {}.".format(val))
return tf_to_ov_type[val]
if 'torch' in sys.modules:
import torch

Expand All @@ -48,7 +63,7 @@ def to_ov_type(val):
torch.int16: ov.Type.i16,
torch.int32: ov.Type.i32,
torch.int64: ov.Type.i64,
torch.bool: ov.Type.boolean,
torch.bool: ov.Type.boolean
}
if val not in torch_to_ov_type:
raise Exception("The provided data time is not supported {}.".format(val))
Expand Down

0 comments on commit 9fe63a9

Please sign in to comment.