Skip to content

Commit

Permalink
reorg code: better hide methods in py wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
akuporos committed Jan 10, 2025
1 parent 2cfe32e commit 2469726
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 47 deletions.
2 changes: 1 addition & 1 deletion src/bindings/python/src/openvino/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from openvino._ov_api import CompiledModel
from openvino._ov_api import InferRequest
from openvino._ov_api import AsyncInferQueue
from openvino._ov_api import Op

from openvino.runtime import Symbol
from openvino.runtime import Dimension
Expand All @@ -55,7 +56,6 @@

from openvino._pyopenvino import RemoteContext
from openvino._pyopenvino import RemoteTensor
from openvino._pyopenvino import Op
from openvino._pyopenvino import OpExtension

# Import opsets
Expand Down
12 changes: 12 additions & 0 deletions src/bindings/python/src/openvino/_ov_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openvino._pyopenvino import Core as CoreBase
from openvino._pyopenvino import CompiledModel as CompiledModelBase
from openvino._pyopenvino import AsyncInferQueue as AsyncInferQueueBase
from openvino._pyopenvino import Op as OpBase
from openvino._pyopenvino import Tensor
from openvino._pyopenvino import Node

Expand All @@ -24,6 +25,17 @@
)


class Op(OpBase):
def __init__(self, py_obj, inputs=None) -> None:
super().__init__(py_obj)
print("super ok")
self._initialize_type_info()
print("__initialize_type_info ok")
if inputs is not None:
self.set_arguments(inputs)
self.constructor_validate_and_infer_types()


class Model:
def __init__(self, *args: Any, **kwargs: Any) -> None:
if args and not kwargs:
Expand Down
33 changes: 22 additions & 11 deletions src/bindings/python/src/pyopenvino/graph/op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ std::shared_ptr<ov::Node> PyOp::clone_with_new_inputs(const ov::OutputVector& ne
}

const ov::op::Op::type_info_t& PyOp::get_type_info() const {
std::cout << reinterpret_cast<size_t>(m_type_info.get()) << std::endl;
// std::cout << reinterpret_cast<size_t>(m_type_info.get()) << std::endl;
return *m_type_info;
}

Expand All @@ -57,6 +57,23 @@ bool PyOp::has_evaluate() const {
PYBIND11_OVERRIDE(bool, ov::op::Op, has_evaluate);
}

void PyOp::initialize_type_info() {
py::gil_scoped_acquire gil; // Acquire the GIL while in this scope.
std::cout << "jello" << std::endl;
// Try to look up the overridden method on the Python side.
py::function overriden_py_method = pybind11::get_override(this, "get_type_info");
if (overriden_py_method) {
// TODO: Write a test on this behavior!
std::cout << "use type name" << std::endl;
const auto type_info_from_py = overriden_py_method();
if (!py::isinstance<ov::DiscreteTypeInfo>(type_info_from_py)) {
// TODO: Rewrite me?
OPENVINO_THROW("operation type_info must be an instance of DiscreteTypeInfo, but ", py::str(py::type::of(type_info_from_py)), " is passed.");
}
m_type_info = type_info_from_py.cast<std::shared_ptr<ov::DiscreteTypeInfo>>();
}
}

void regclass_graph_Op(py::module m) {
py::class_<ov::op::Op, std::shared_ptr<ov::op::Op>, PyOp, ov::Node> op(m, "Op");

Expand All @@ -65,14 +82,8 @@ void regclass_graph_Op(py::module m) {
return PyOp(py_obj);
}));

op.def(py::init([](const py::object& py_obj, const py::object& inputs) {
return PyOp(py_obj, inputs);
}));

// op.def(py::init([](const py::object& py_obj, py::object& inputs) {
// if (inputs.is_none()) {
// return PyOp(py_obj);
// }
// return PyOp(py_obj);
// }), py::arg("py_obj"), py::arg("inputs"));
op.def("_initialize_type_info", [](PyOp& self){
std::cout << "init ti" << std::endl;
self.initialize_type_info();
});
}
24 changes: 6 additions & 18 deletions src/bindings/python/src/pyopenvino/graph/op.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,13 @@ class PyOp : public ov::op::Op {
PyOp() = default;

// Keeps a reference to the Python object to manage its lifetime
PyOp(const py::object& py_obj, const py::object& inputs=py::none()) : py_handle(py_obj) {
py::gil_scoped_acquire gil; // Acquire the GIL while in this scope.
std::cout << "jello" << std::endl;
// Try to look up the overridden method on the Python side.
py::function overrided_py_method = pybind11::get_override(this, "get_type_info");
if (overrided_py_method) { // method is found
const auto result = overrided_py_method(); // Call the Python function.
m_type_info = result.cast<std::shared_ptr<DiscreteTypeInfoWrapper>>();
} else {
const auto py_class_name = py_handle.get_type().attr("__name__").cast<std::string>();
m_type_info = std::make_shared<DiscreteTypeInfoWrapper>(py_class_name, "extension");
}
if (!inputs.is_none()) {
std::cout << "here" << std::endl;
this->set_arguments(inputs.cast<const ov::NodeVector>());
this->constructor_validate_and_infer_types();
}
PyOp(const py::object& py_obj) : py_handle(py_obj) {
const auto py_class_name = py_handle.get_type().attr("__name__").cast<std::string>();
m_type_info = std::make_shared<DiscreteTypeInfoWrapper>(py_class_name, "extension");
}

void initialize_type_info();

void validate_and_infer_types() override;

bool visit_attributes(ov::AttributeVisitor& value) override;
Expand All @@ -53,7 +41,7 @@ class PyOp : public ov::op::Op {

private:
py::object py_handle; // Holds the Python object to manage its lifetime
std::shared_ptr<DiscreteTypeInfoWrapper> m_type_info;
std::shared_ptr<ov::DiscreteTypeInfo> m_type_info;
};

void regclass_graph_Op(py::module m);
8 changes: 7 additions & 1 deletion src/bindings/python/src/pyopenvino/graph/op_extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,17 @@ class PyOpExtension : public ov::BaseOpExtension {
py::object type_info;
try {
// get_type_info() is a static method
std::cout << "before static methods" << std::endl;
type_info = py_handle_dtype.attr("get_type_info")();
std::cout << "after static methods" << std::endl;
} catch (const std::exception&) {
try {
// get_type_info() is a class method
type_info = py_handle_dtype().attr("get_type_info")();
std::cout << "before class methods" << std::endl;
auto obj = py_handle_dtype();
std::cout << "afte obj" << std::endl;
type_info = obj.attr("get_type_info")();
std::cout << "afte class methods" << std::endl;
} catch (const std::exception &exc) {
OPENVINO_THROW("Creation of OpExtension failed: ", exc.what());
}
Expand Down
31 changes: 15 additions & 16 deletions src/bindings/python/tests/test_graph/test_custom_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,13 @@ class CustomAdd(Op):
class_type_info = DiscreteTypeInfo("CustomAdd", "extension")

def __init__(self, inputs=None):
super().__init__(self)
if inputs is not None:
self.set_arguments(inputs)
self.constructor_validate_and_infer_types()
super().__init__(self, inputs)
# if inputs is not None:
# self.set_arguments(inputs)
# self.constructor_validate_and_infer_types()

def validate_and_infer_types(self):
print("from validate")
self.set_output_type(0, self.get_input_element_type(0), self.get_input_partial_shape(0))

def clone_with_new_inputs(self, new_inputs):
Expand All @@ -87,11 +88,12 @@ class CustomOpWithAttribute(Op):
class_type_info = DiscreteTypeInfo("CustomOpWithAttribute", "extension")

def __init__(self, inputs=None, attrs=None):
super().__init__(self)
if attrs is not None or inputs is not None:
self._attrs = attrs
self.set_arguments(inputs)
self.constructor_validate_and_infer_types()
super().__init__(self, inputs)
self._attrs = attrs
# if attrs is not None or inputs is not None:
# self._attrs = attrs
# self.set_arguments(inputs)
# self.constructor_validate_and_infer_types()

def validate_and_infer_types(self):
self.set_output_type(0, self.get_input_element_type(0), self.get_input_partial_shape(0))
Expand Down Expand Up @@ -212,12 +214,9 @@ def validate_and_infer_types(self):
class CustomSimpleOpWithAttribute(Op):
class_type_info = DiscreteTypeInfo("CustomSimpleOpWithAttribute", "extension")

def __init__(self, inputs=None, attrs=None):
super().__init__(self)
def __init__(self, inputs=None, **attrs):
super().__init__(self, inputs)
self._attrs = attrs
if attrs is not None or inputs is not None:
self.set_arguments(inputs)
self.constructor_validate_and_infer_types()

def validate_and_infer_types(self):
self.set_output_type(0, self.get_input_element_type(0), self.get_input_partial_shape(0))
Expand Down Expand Up @@ -249,7 +248,7 @@ def test_op_extension(prepared_paths):
custom_simple = CustomSimpleOp(inputs=[param1, param2])
print("6")
custom_simple.set_friendly_name("test_add")
custom_with_attribute = CustomSimpleOpWithAttribute(inputs=[custom_simple], attrs={"value_str": "test_attribute"})
custom_with_attribute = CustomSimpleOpWithAttribute(inputs=[custom_simple], value_str="test_attribute")
custom_add = CustomAdd(inputs=[custom_with_attribute])
res = ops.result(custom_add, name="result")
simple_model = Model(res, [param1, param2], "SimpleModel")
Expand Down Expand Up @@ -291,4 +290,4 @@ def get_type_info(self):

with pytest.raises(RuntimeError) as e:
core.add_extension(OpWithBadClassTypeInfo)
assert "operation type_info must be an instance of DiscreteTypeInfo, but <class \'str\'> is passed." in str(e.value)
assert "operation type_info must be an instance of DiscreteTypeInfo, but <class 'str'> is passed." in str(e.value)

0 comments on commit 2469726

Please sign in to comment.