From bb8cd13f1f07dd0e664d324500aae1abb8db9e3a Mon Sep 17 00:00:00 2001 From: Boyan-MILANOV Date: Mon, 22 Jul 2024 12:17:01 +0200 Subject: [PATCH 1/4] Add support for list arguments in python calls. Fix int encoding bug --- fickling/fickle.py | 35 ++++++++++++++++++++++++++++++++--- test/test_pickle.py | 22 ++++++++++++++++++++++ 2 files changed, 54 insertions(+), 3 deletions(-) diff --git a/fickling/fickle.py b/fickling/fickle.py index f167c46..a88bfdd 100644 --- a/fickling/fickle.py +++ b/fickling/fickle.py @@ -318,7 +318,7 @@ def encode_body(self) -> bytes: st = self.struct_types[self.num_bytes] if not self.signed: st = st.upper() - return struct.pack(f"{self.endianness.value}{st}") + return struct.pack(f"{self.endianness.value}{st}", self.arg) @classmethod def validate(cls, obj): @@ -415,6 +415,34 @@ def insert(self, index: int, opcode: Opcode): self._ast = None self._properties = None + def _is_constant_type(self, obj: Any) -> bool: + return isinstance(obj, (int, float, str, bytes)) + + def _encode_python_obj(self, obj: Any) -> list[Opcode]: + """Create an opcode sequence that builds an arbitrary python object on the top of the + pickle VM stack""" + if self._is_constant_type(obj): + return [ConstantOpcode.new(obj)] + elif isinstance(obj, list): + res = [Mark()] + for item in obj: + if self._is_constant_type(item): + res.append(ConstantOpcode.new(item)) + else: + res += self._encode_python_obj(item) + res.append(List()) + return res + else: + raise ValueError(f"Type {type(obj)} not supported") + + def insert_python_obj(self, index: int, obj: Any) -> int: + """Insert an opcode sequence that constructs a python object on the stack. + Returns the number of opcodes inserted""" + opcodes = self._encode_python_obj(obj) + for i, opcode in enumerate(opcodes): + self.insert(index+i, opcode) + return len(opcodes) + def insert_python( self, *args, @@ -440,8 +468,9 @@ def insert_python( self.insert(i, Mark()) i += 1 for arg in args: - self.insert(i, ConstantOpcode.new(arg)) - i += 1 + i += self.insert_python_obj(i, arg) + # self.insert(i, ConstantOpcode.new(arg)) + # i += 1 self.insert(i, Tuple()) i += 1 if run_first: diff --git a/test/test_pickle.py b/test/test_pickle.py index e36c470..90a0b5b 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -123,6 +123,28 @@ def test_insert(self): evaluated = loads(loaded.dumps()) self.assertEqual([5, 6, 7, 8], evaluated) + def test_insert_list_arg(self): + pickled = dumps([1, 2, 3, 4]) + loaded = Pickled.load(pickled) + self.assertIsInstance(loaded[-1], fpickle.Stop) + loaded.insert_python( + [1, 2, ['a', 'b'], 3], + module='builtins', + attr='tuple', + use_output_as_unpickle_result=True, + run_first=False, + ) + self.assertIsInstance(loaded[-1], fpickle.Stop) + + # Make sure the injected code cleans up the stack after itself: + interpreter = Interpreter(loaded) + interpreter.run() + self.assertEqual(len(interpreter.stack), 0) + + # Make sure the output is correct + evaluated = loads(loaded.dumps()) + self.assertEqual((1, 2, ['a', 'b'], 3), evaluated) + def test_insert_run_last(self): pickled = dumps([1, 2, 3, 4]) loaded = Pickled.load(pickled) From 3e24d27bd465cbbde7aa3ff2d7221265971614c7 Mon Sep 17 00:00:00 2001 From: Boyan-MILANOV Date: Mon, 22 Jul 2024 13:10:44 +0200 Subject: [PATCH 2/4] Fix typing and lint --- fickling/fickle.py | 4 ++-- test/test_pickle.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/fickling/fickle.py b/fickling/fickle.py index a88bfdd..fa2cdd9 100644 --- a/fickling/fickle.py +++ b/fickling/fickle.py @@ -418,7 +418,7 @@ def insert(self, index: int, opcode: Opcode): def _is_constant_type(self, obj: Any) -> bool: return isinstance(obj, (int, float, str, bytes)) - def _encode_python_obj(self, obj: Any) -> list[Opcode]: + def _encode_python_obj(self, obj: Any) -> List[Opcode]: """Create an opcode sequence that builds an arbitrary python object on the top of the pickle VM stack""" if self._is_constant_type(obj): @@ -440,7 +440,7 @@ def insert_python_obj(self, index: int, obj: Any) -> int: Returns the number of opcodes inserted""" opcodes = self._encode_python_obj(obj) for i, opcode in enumerate(opcodes): - self.insert(index+i, opcode) + self.insert(index + i, opcode) return len(opcodes) def insert_python( diff --git a/test/test_pickle.py b/test/test_pickle.py index 90a0b5b..80a197a 100644 --- a/test/test_pickle.py +++ b/test/test_pickle.py @@ -128,9 +128,9 @@ def test_insert_list_arg(self): loaded = Pickled.load(pickled) self.assertIsInstance(loaded[-1], fpickle.Stop) loaded.insert_python( - [1, 2, ['a', 'b'], 3], - module='builtins', - attr='tuple', + [1, 2, ["a", "b"], 3], + module="builtins", + attr="tuple", use_output_as_unpickle_result=True, run_first=False, ) @@ -143,7 +143,7 @@ def test_insert_list_arg(self): # Make sure the output is correct evaluated = loads(loaded.dumps()) - self.assertEqual((1, 2, ['a', 'b'], 3), evaluated) + self.assertEqual((1, 2, ["a", "b"], 3), evaluated) def test_insert_run_last(self): pickled = dumps([1, 2, 3, 4]) From 88d56834ffdff576dd848ddbe70c0d85038803a2 Mon Sep 17 00:00:00 2001 From: Boyan-MILANOV Date: Mon, 22 Jul 2024 13:22:16 +0200 Subject: [PATCH 3/4] Pin ruff version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 61cf433..e265689 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ requires-python = ">=3.8" [project.optional-dependencies] torch = ["torch >= 2.1.0", "torchvision >= 0.16.1"] -lint = ["black", "mypy", "ruff"] +lint = ["black", "mypy", "ruff==0.5.4"] test = ["pytest", "pytest-cov", "coverage[toml]", "torch >= 2.1.0", "torchvision >= 0.16.1"] dev = ["build", "fickling[lint,test]", "twine", "torch >= 2.1.0", "torchvision >= 0.16.1"] examples = ["numpy", "pytorchfi"] From 6e5d69230148dcccefab67dfea76a60990b88ed1 Mon Sep 17 00:00:00 2001 From: Boyan-MILANOV Date: Mon, 22 Jul 2024 13:43:56 +0200 Subject: [PATCH 4/4] Downgrade ruff --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e265689..2cef975 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ requires-python = ">=3.8" [project.optional-dependencies] torch = ["torch >= 2.1.0", "torchvision >= 0.16.1"] -lint = ["black", "mypy", "ruff==0.5.4"] +lint = ["black", "mypy", "ruff==0.2.0"] test = ["pytest", "pytest-cov", "coverage[toml]", "torch >= 2.1.0", "torchvision >= 0.16.1"] dev = ["build", "fickling[lint,test]", "twine", "torch >= 2.1.0", "torchvision >= 0.16.1"] examples = ["numpy", "pytorchfi"]