From ab574af6d7050db95fe3c7cdedf4f425960939bd Mon Sep 17 00:00:00 2001 From: Eta Date: Thu, 19 Oct 2023 18:28:30 -0500 Subject: [PATCH] feat(serialization): Support serializing only persistent buffers --- CHANGELOG.md | 17 +++++++++++ tensorizer/_version.py | 2 +- tensorizer/serialization.py | 61 ++++++++++++++++++++++++++++++++++--- tests/test_serialization.py | 46 ++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5322c770..4d00e48e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Added + +- `TensorSerializer.write_module` now accepts `include_non_persistent_buffers` + as a keyword-only boolean argument that can be set to `False` to exclude + buffers from serialization that were originally registered to the module + through calling `torch.nn.Module.register_buffer` with `persistent=False` + - `torch.nn.Module.state_dict` never includes non-persistent buffers, + so setting this to `False` will more closely match the behaviour + of `state_dict` serialization + - `TensorSerializer.write_module` used to always include non-persistent + buffers + - The default (`include_non_persistent_buffers=True`) matches the old + behaviour + ## [2.5.1] - 2023-10-17 ### Changed @@ -190,6 +206,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `get_gpu_name` - `no_init_or_tensor` +[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.5.1...HEAD [2.5.1]: https://github.com/coreweave/tensorizer/compare/v2.5.0...v2.5.1 [2.5.0]: https://github.com/coreweave/tensorizer/compare/v2.4.0...v2.5.0 [2.4.0]: https://github.com/coreweave/tensorizer/compare/v2.3.0...v2.4.0 diff --git a/tensorizer/_version.py
b/tensorizer/_version.py index 7a2056f5..d4fe316b 100644 --- a/tensorizer/_version.py +++ b/tensorizer/_version.py @@ -1 +1 @@ -__version__ = "2.5.1" +__version__ = "2.6.0.dev0" diff --git a/tensorizer/serialization.py b/tensorizer/serialization.py index e3be1d4e..b2edffb4 100644 --- a/tensorizer/serialization.py +++ b/tensorizer/serialization.py @@ -34,6 +34,7 @@ List, Optional, Sequence, + Set, Tuple, Union, ) @@ -767,6 +768,8 @@ def __init__( self._metadata, self._metadata_raw = _MetadataDeserializer.from_io( self._file, self._file_header.tensor_count ) + if not self._metadata: + raise ValueError("Tensor index in the file is empty") # filter_func is a test that determines the tensor names to read. # If filter_func is None, all tensors are read. if filter_func is not None: @@ -798,7 +801,7 @@ def __init__( self._plaid_mode_buffer_count = 1 else: self._plaid_mode_buffer_count = 2 - single_largest_tensor = max(tensor_sizes.values()) + single_largest_tensor = max(tensor_sizes.values(), default=0) # Round up to the nearest multiple of the page size # Just so that more reads happen on page boundaries single_largest_tensor -= single_largest_tensor % -mmap.PAGESIZE @@ -2600,7 +2603,11 @@ def _bulk_write(self, tensors: Iterable[_WriteSpec]): self._sync_prologue_state() def write_module( - self, m: torch.nn.Module, remove_tensors: bool = False + self, + m: torch.nn.Module, + remove_tensors: bool = False, + *, + include_non_persistent_buffers: bool = True, ) -> None: """ Serializes an entire ``torch.nn.Module`` instance at once, @@ -2617,13 +2624,22 @@ def write_module( remove_tensors: Whether to delete each tensor from `m` after serializing it. Deleted tensors are replaced with ``None``. + include_non_persistent_buffers: Whether to serialize buffers + registered with ``persistent=False``. + Set to ``False`` to match the behaviour of + ``torch.nn.Module.state_dict()``, + which saves only persistent buffers. + The default may change to ``False`` in a later version. 
""" + modules = tuple(m.named_modules()) + def extract_tensors(): chain = itertools.chain repeat = itertools.repeat callback = None - for idx, (module_name, module) in enumerate(m.named_modules()): + for idx, (module_name, module) in enumerate(modules): + module: torch.nn.Module parameters = module.named_parameters(recurse=False) buffers = module.named_buffers(recurse=False) for (name, tensor), tensor_type in chain( @@ -2641,7 +2657,44 @@ def extract_tensors(): callback=callback, ) - self._bulk_write(extract_tensors()) + def persistent_buffers() -> Set[str]: + persistent_buffers_set: Set[str] = { + name + for name, _ in m.named_buffers( + recurse=True, remove_duplicate=False + ) + } + if hasattr(m, "_non_persistent_buffers_set"): + # Direct access to the _non_persistent_buffers_set attribute + # is an order of magnitude faster than generating + # a state_dict, but this is a private interface, and thus + # not guaranteed to remain stable between torch versions + + for module_name, module in modules: + # noinspection PyProtectedMember + persistent_buffers_set.difference_update( + f"{module_name}.{name}" + for name in module._non_persistent_buffers_set + ) + else: + # Filtering down to only the buffers that appear + # in the state_dict() representation is the supported way + # to access the persistent buffer list, but is much slower + persistent_buffers_set.intersection_update( + m.state_dict().keys() + ) + + return persistent_buffers_set + + all_tensors = extract_tensors() + + if not include_non_persistent_buffers: + persistent = persistent_buffers() + all_tensors = ( + spec for spec in all_tensors if spec.name in persistent + ) + + self._bulk_write(all_tensors) def write_state_dict(self, state_dict: Dict): """ diff --git a/tests/test_serialization.py b/tests/test_serialization.py index b28c767e..e96d73f1 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -251,6 +251,52 @@ def test_bfloat16(self): self.assertTrue(torch.equal(tensor, 
deserialized_tensor)) + def test_persistent_buffers(self): + shape = (50, 50) + persistent_buffer = torch.normal(0, 0.5, shape) + non_persistent_buffer = torch.normal(0, 0.5, shape) + nested_module = torch.nn.Module() + nested_module.register_buffer( + "persistent_buffer", persistent_buffer, persistent=True + ) + nested_module.register_buffer( + "non_persistent_buffer", + non_persistent_buffer, + persistent=False, + ) + module = torch.nn.Module() + module.register_module("nested", nested_module) + model = torch.nn.Module() + model.register_module("module", module) + + for include in (True, False): + with self.subTest( + msg=f"Testing include_non_persistent_buffers={include}" + ): + tensorized_file = tempfile.NamedTemporaryFile( + "wb+", delete=False + ) + try: + serializer = TensorSerializer(tensorized_file) + serializer.write_module( + model, include_non_persistent_buffers=include + ) + serializer.close() + + with open(tensorized_file.name, "rb") as in_file: + with TensorDeserializer( + in_file, device="cpu", lazy_load=True + ) as deserializer: + assertion = ( + self.assertIn if include else self.assertNotIn + ) + assertion( + "module.nested.non_persistent_buffer", + deserializer.keys(), + ) + finally: + os.unlink(tensorized_file.name) + class TestDeserialization(unittest.TestCase): _serialized_model_path: str