
Commit

Initial updates for et
cavusmustafa committed Jan 16, 2025
1 parent 40b19c8 commit 21b4c39
Showing 2 changed files with 41 additions and 5 deletions.
src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py (35 additions, 2 deletions)
@@ -192,8 +192,12 @@ def __init__(self, pt_module, fx_gm=None, nodes=None,
                         found_types.append(
                             OVAny(pt_to_ov_type_map[str(value.meta['tensor_meta'].dtype)]))
                     else:
-                        found_shapes.append(None)
-                        found_types.append(None)
+                        if hasattr(value, "meta") and ('val' in value.meta.keys()):
+                            found_shapes.append(value.meta["val"].shape)
+                            found_types.append(None)
+                        else:
+                            found_shapes.append(None)
+                            found_types.append(None)
                 elif value.op == 'output':
                     # Instead of putting output index, refer to its target
                     uargs = self.unpack_containers(value.args)
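
The fallback added above targets graphs produced by torch.export / ExecuTorch lowering, where a placeholder usually carries a FakeTensor under node.meta['val'] rather than a 'tensor_meta' entry, so its shape can still be recovered. A minimal sketch of that metadata, using a made-up Add module and tensor sizes that are not part of the commit:

import torch

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

# torch.export records a FakeTensor for each value under node.meta["val"]
ep = torch.export.export(Add(), (torch.ones(2, 3), torch.ones(2, 3)))

for node in ep.graph_module.graph.nodes:
    if node.op == "placeholder" and "val" in node.meta:
        # this is the shape the new branch above would append to found_shapes
        print(node.name, tuple(node.meta["val"].shape))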
@@ -453,3 +457,32 @@ def mark_node(self, node):
         node.set_friendly_name(name)
         super().mark_node(node)
         return node
+
+
+class ExecuTorchPythonDecoder (TorchFXPythonDecoder):
+
+    def __init__(self, pt_module, fx_gm=None, nodes=None,
+                 mark_node_callback=None, input_shapes=[], input_types=[], dynamic_shapes=False):
+        TorchFXPythonDecoder.__init__(self, pt_module, fx_gm, nodes, mark_node_callback, input_shapes, input_types, dynamic_shapes)
+
+    def visit_subgraph(self, node_visitor):
+        # make sure topological order is satisfied
+        for node in self._nodes:
+            if node.op == 'placeholder' or node.op == 'output':
+                continue  # skipping non-operational nodes
+            if node.op == 'call_function' and str(node.target) in ["aten._assert_async.msg"]:
+                continue
+            decoder = ExecuTorchPythonDecoder(
+                node, self.fx_gm, self._nodes, mark_node_callback=self.mark_node_callback)
+            self.m_decoders.append(decoder)
+            node_visitor(decoder)
+
+    def get_op_type(self):
+        if "getitem" in str(self.pt_module.target):
+            return str(self.pt_module.target)
+        elif self.pt_module.op == 'call_function':
+            return self.pt_module.target.__name__
+        elif self.pt_module.op == 'get_attr':
+            return 'get_attr'  # FIXME should be aligned with get_attr from TS implementation
+        else:
+            return 'UNKNOWN_TYPE_' + str(self.pt_module.op)
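
The new ExecuTorchPythonDecoder keeps the TorchFXPythonDecoder behaviour but skips aten._assert_async.msg calls while visiting the graph and reports getitem targets by their full name. A rough usage sketch, assuming an illustrative Add module exported with torch.export (the import mirrors the one added in the second file below):

import torch
from openvino.frontend.pytorch.fx_decoder import ExecuTorchPythonDecoder

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

gm = torch.export.export(Add(), (torch.ones(2), torch.ones(2))).module()

decoder = ExecuTorchPythonDecoder(gm)
# placeholder, output and aten._assert_async.msg nodes are not visited
decoder.visit_subgraph(lambda d: print(d.get_op_type()))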
@@ -13,7 +13,7 @@
 from torch.fx import GraphModule
 
 from openvino.frontend import FrontEndManager
-from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder
+from openvino.frontend.pytorch.fx_decoder import TorchFXPythonDecoder, ExecuTorchPythonDecoder
 from openvino import Core, Type, PartialShape, serialize
 from openvino.frontend.pytorch.torchdynamo.backend_utils import _get_cache_dir, _get_device, _get_config, _is_cache_dir_in_config
 
@@ -78,7 +78,7 @@ def openvino_compile_cached_model(cached_model_path, options, *example_inputs):
 
     return compiled_model
 
-def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, options=None):
+def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, options=None, executorch=False):
     core = Core()
 
     device = _get_device(options)
@@ -101,7 +101,10 @@ def openvino_compile(gm: GraphModule, *args, model_hash_str: str = None, options
         input_types.append(input_data.type())
         input_shapes.append(input_data.size())
 
-    decoder = TorchFXPythonDecoder(gm)
+    if executorch:
+        decoder = ExecuTorchPythonDecoder(gm)
+    else:
+        decoder = TorchFXPythonDecoder(gm)
 
     im = fe.load(decoder)
 
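With the new executorch flag, callers opt in to the ExecuTorch decoder per compilation. A hedged sketch of such a call site; the module path in the import is an assumption (the changed file's name is not shown above), and the Add module and example inputs are illustrative:

import torch
# module path assumed, not confirmed by the diff above
from openvino.frontend.pytorch.torchdynamo.execute import openvino_compile

class Add(torch.nn.Module):
    def forward(self, x, y):
        return x + y

example = (torch.ones(2, 3), torch.ones(2, 3))
gm = torch.export.export(Add(), example).module()

# executorch=True selects ExecuTorchPythonDecoder instead of TorchFXPythonDecoder
compiled = openvino_compile(gm, *example, executorch=True)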
