feat(serialization): Support serializing only persistent buffers
Eta0 committed Oct 19, 2023
1 parent ebd12dd commit ab574af
Showing 4 changed files with 121 additions and 5 deletions.
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -5,6 +5,22 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added

- `TensorSerializer.write_module` now accepts `include_non_persistent_buffers`
  as a keyword-only boolean argument that can be set to `False` to exclude
  from serialization any buffers that were originally registered
  to the module by calling `torch.nn.Module.register_buffer`
  with `persistent=False`
  - `torch.nn.Module.state_dict` never includes non-persistent buffers,
    so setting this to `False` will more closely match the behaviour
    of `state_dict` serialization
  - `TensorSerializer.write_module` used to always include non-persistent
    buffers
  - The default (`include_non_persistent_buffers=True`) matches the old
    behaviour

## [2.5.1] - 2023-10-17

### Changed
@@ -190,6 +206,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
- `get_gpu_name`
- `no_init_or_tensor`

[Unreleased]: https://github.com/coreweave/tensorizer/compare/v2.5.1...HEAD
[2.5.1]: https://github.com/coreweave/tensorizer/compare/v2.5.0...v2.5.1
[2.5.0]: https://github.com/coreweave/tensorizer/compare/v2.4.0...v2.5.0
[2.4.0]: https://github.com/coreweave/tensorizer/compare/v2.3.0...v2.4.0
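For illustration, a minimal sketch of the new keyword in use (not part of the commit; the module layout and output path are hypothetical, and the buffers are nested in a submodule as in the test further below):

```python
import torch
from tensorizer import TensorSerializer

inner = torch.nn.Module()
inner.register_buffer("stats", torch.zeros(10), persistent=True)
inner.register_buffer("cache", torch.zeros(10), persistent=False)
model = torch.nn.Module()
model.register_module("inner", inner)

# state_dict() contains only the persistent buffer...
assert set(model.state_dict()) == {"inner.stats"}

# ...and include_non_persistent_buffers=False makes write_module match it.
serializer = TensorSerializer("model.tensors")
serializer.write_module(model, include_non_persistent_buffers=False)
serializer.close()
```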
2 changes: 1 addition & 1 deletion tensorizer/_version.py
@@ -1 +1 @@
__version__ = "2.5.1"
__version__ = "2.6.0.dev0"
61 changes: 57 additions & 4 deletions tensorizer/serialization.py
@@ -34,6 +34,7 @@
List,
Optional,
Sequence,
Set,
Tuple,
Union,
)
@@ -767,6 +768,8 @@ def __init__(
self._metadata, self._metadata_raw = _MetadataDeserializer.from_io(
self._file, self._file_header.tensor_count
)
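# An empty tensor index means the file contains no tensors at all,
# so fail early with a clear error instead of loading nothing.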
if not self._metadata:
raise ValueError("Tensor index in the file is empty")
# filter_func is a test that determines the tensor names to read.
# If filter_func is None, all tensors are read.
if filter_func is not None:
@@ -798,7 +801,7 @@ def __init__(
self._plaid_mode_buffer_count = 1
else:
self._plaid_mode_buffer_count = 2
- single_largest_tensor = max(tensor_sizes.values())
+ single_largest_tensor = max(tensor_sizes.values(), default=0)
# Round up to the nearest multiple of the page size
# Just so that more reads happen on page boundaries
single_largest_tensor -= single_largest_tensor % -mmap.PAGESIZE
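# (In Python, n % -d lies in (-d, 0], so subtracting it rounds n up
# to a multiple of d: e.g. 5000 - (5000 % -4096) == 8192.)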
@@ -2600,7 +2603,11 @@ def _bulk_write(self, tensors: Iterable[_WriteSpec]):
self._sync_prologue_state()

def write_module(
- self, m: torch.nn.Module, remove_tensors: bool = False
+ self,
+ m: torch.nn.Module,
+ remove_tensors: bool = False,
+ *,
+ include_non_persistent_buffers: bool = True,
) -> None:
"""
Serializes an entire ``torch.nn.Module`` instance at once,
@@ -2617,13 +2624,22 @@ def write_module(
remove_tensors: Whether to delete each tensor from `m`
after serializing it.
Deleted tensors are replaced with ``None``.
include_non_persistent_buffers: Whether to serialize buffers
registered with ``persistent=False``.
Set to ``False`` to match the behaviour of
``torch.nn.Module.state_dict()``,
which saves only persistent buffers.
The default may change to ``False`` in a later version.
"""

modules = tuple(m.named_modules())

def extract_tensors():
chain = itertools.chain
repeat = itertools.repeat
callback = None
- for idx, (module_name, module) in enumerate(m.named_modules()):
+ for idx, (module_name, module) in enumerate(modules):
+     module: torch.nn.Module
parameters = module.named_parameters(recurse=False)
buffers = module.named_buffers(recurse=False)
for (name, tensor), tensor_type in chain(
@@ -2641,7 +2657,44 @@ def extract_tensors():
callback=callback,
)

- self._bulk_write(extract_tensors())
def persistent_buffers() -> Set[str]:
persistent_buffers_set: Set[str] = {
name
for name, _ in m.named_buffers(
recurse=True, remove_duplicate=False
)
}
if hasattr(m, "_non_persistent_buffers_set"):
# Direct access to the _non_persistent_buffers_set attribute
# is an order of magnitude faster than generating
# a state_dict, but this is a private interface, and thus
# not guaranteed to remain stable between torch versions

for module_name, module in modules:
# noinspection PyProtectedMember
persistent_buffers_set.difference_update(
f"{module_name}.{name}"
for name in module._non_persistent_buffers_set
)
else:
# Filtering down to only the buffers that appear
# in the state_dict() representation is the supported way
# to access the persistent buffer list, but is much slower
persistent_buffers_set.intersection_update(
m.state_dict().keys()
)

return persistent_buffers_set

all_tensors = extract_tensors()

if not include_non_persistent_buffers:
persistent = persistent_buffers()
all_tensors = (
spec for spec in all_tensors if spec.name in persistent
)

self._bulk_write(all_tensors)

def write_state_dict(self, state_dict: Dict):
"""
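For illustration, a standalone sketch of the two buffer-discovery strategies that `persistent_buffers()` above chooses between (not part of the commit; assumes a recent `torch`, and the `demo`/`sub` modules are hypothetical — only submodule buffers are used, matching the dotted-prefix logic above):

```python
import torch

demo = torch.nn.Module()
sub = torch.nn.Module()
sub.register_buffer("kept", torch.zeros(1), persistent=True)
sub.register_buffer("skipped", torch.ones(1), persistent=False)
demo.register_module("sub", sub)

all_buffers = {name for name, _ in demo.named_buffers(remove_duplicate=False)}

# Fast path: subtract each submodule's private _non_persistent_buffers_set.
fast = set(all_buffers)
for module_name, module in demo.named_modules():
    fast.difference_update(
        f"{module_name}.{name}" for name in module._non_persistent_buffers_set
    )

# Supported path: keep only the names that survive into state_dict().
slow = all_buffers & demo.state_dict().keys()

assert fast == slow == {"sub.kept"}
```

Both strategies agree; the private-attribute walk just avoids materializing a full `state_dict`, at the cost of depending on an unstable torch internal.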
46 changes: 46 additions & 0 deletions tests/test_serialization.py
@@ -251,6 +251,52 @@ def test_bfloat16(self):

self.assertTrue(torch.equal(tensor, deserialized_tensor))

def test_persistent_buffers(self):
shape = (50, 50)
persistent_buffer = torch.normal(0, 0.5, shape)
non_persistent_buffer = torch.normal(0, 0.5, shape)
nested_module = torch.nn.Module()
nested_module.register_buffer(
"persistent_buffer", persistent_buffer, persistent=True
)
nested_module.register_buffer(
"non_persistent_buffer",
non_persistent_buffer,
persistent=False,
)
module = torch.nn.Module()
module.register_module("nested", nested_module)
model = torch.nn.Module()
model.register_module("module", module)

for include in (True, False):
with self.subTest(
msg=f"Testing include_non_persistent_buffers={include}"
):
tensorized_file = tempfile.NamedTemporaryFile(
"wb+", delete=False
)
try:
serializer = TensorSerializer(tensorized_file)
serializer.write_module(
model, include_non_persistent_buffers=include
)
serializer.close()

with open(tensorized_file.name, "rb") as in_file:
with TensorDeserializer(
in_file, device="cpu", lazy_load=True
) as deserializer:
assertion = (
self.assertIn if include else self.assertNotIn
)
assertion(
"module.nested.non_persistent_buffer",
deserializer.keys(),
)
finally:
os.unlink(tensorized_file.name)


class TestDeserialization(unittest.TestCase):
_serialized_model_path: str
