diff --git a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py index e5ffdef174618d..b44d0a976d6f4c 100644 --- a/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py +++ b/src/bindings/python/src/openvino/frontend/pytorch/fx_decoder.py @@ -192,12 +192,8 @@ def __init__(self, pt_module, fx_gm=None, nodes=None, found_types.append( OVAny(pt_to_ov_type_map[str(value.meta['tensor_meta'].dtype)])) else: - if hasattr(value, "meta") and ('val' in value.meta.keys()): - found_shapes.append(value.meta["val"].shape) - found_types.append(None) - else: - found_shapes.append(None) - found_types.append(None) + found_shapes.append(None) + found_types.append(None) elif value.op == 'output': # Instead of putting output index, refer to its target uargs = self.unpack_containers(value.args) @@ -461,9 +457,80 @@ def mark_node(self, node): class ExecuTorchPythonDecoder (TorchFXPythonDecoder): + # TODO: The constructor of ExecuTorchPythonDecoder is mostly similar to the + # constructor TorchFXTorchPythonDecoder. Update this to utilize a common + # implementation. def __init__(self, pt_module, fx_gm=None, nodes=None, mark_node_callback=None, input_shapes=[], input_types=[], dynamic_shapes=False): - TorchFXPythonDecoder.__init__(self, pt_module, fx_gm, nodes, mark_node_callback, input_shapes, input_types, dynamic_shapes) + super().__init__(mark_node_callback) + self.pt_module = pt_module + self.fx_gm = fx_gm if fx_gm is not None else pt_module + self.input_types = [OVAny(pt_to_ov_type_map[str(t)]) + for t in input_types] + self.input_shapes = input_shapes + + self._input_signature = [] + self._example_input = None + + if issubclass(type(pt_module), torch.fx.graph_module.GraphModule): + self._input_is_list = None + self._nodes = list(pt_module.graph.nodes) + found_types = [] + found_shapes = [] + for i, value in enumerate(self._nodes): + if value.op == 'placeholder': + self._inputs.append(i) + self._input_signature.append(value.name) + if hasattr(value, "meta") and ('tensor_meta' in value.meta.keys()) and value.meta['tensor_meta']: + found_shapes.append(value.meta['tensor_meta'].shape) + found_types.append( + OVAny(pt_to_ov_type_map[str(value.meta['tensor_meta'].dtype)])) + else: + if hasattr(value, "meta") and ('val' in value.meta.keys()): + found_shapes.append(value.meta["val"].shape) + found_types.append(None) + else: + found_shapes.append(None) + found_types.append(None) + elif value.op == 'output': + # Instead of putting output index, refer to its target + uargs = self.unpack_containers(value.args) + self._outputs = [(arg[0], self._nodes.index(arg[1])) + for arg in uargs if arg[1] is not None] + for idx, shape in enumerate(found_shapes): + if shape is not None: + new_shape = [] + for dim in shape: + if (dynamic_shapes or type(dim).__name__ == "SymInt"): + new_shape.append(-1) + else: + new_shape.append(dim) + found_shapes[idx] = torch.Size(new_shape) + + if not input_shapes or len(input_shapes) == 0: + self.input_shapes = found_shapes + if not input_types or len(input_types) == 0: + self.input_types = found_types + + if hasattr(pt_module, "forward"): + input_params = inspect.signature(pt_module.forward).parameters + self._input_signature = list(input_params) + + elif issubclass(type(pt_module), torch.fx.Node): + self._nodes = nodes # passed from outer context + + # FIXME: Quadratic complexity nodes*nodes considering the outer loop over all nodes + self._outputs = [("", self._nodes.index(pt_module))] + + self.input_types = [] + for arg in pt_module.args: + if isinstance(arg, torch.fx.Node): + self._inputs.append(self._nodes.index(arg)) + else: + # Not a node, consider it inlined + self._inputs.append(InlinedInput(arg)) + self.input_types.append( + BaseFXDecoder.get_type_for_value(arg)) def visit_subgraph(self, node_visitor): # make sure topological order is satisfied