From bbe73a4a16527bbdfa89e13b797881e54e12b7c8 Mon Sep 17 00:00:00 2001 From: MayDomine <1583143678@qq.com> Date: Thu, 7 Jul 2022 18:04:21 +0800 Subject: [PATCH] add __iter__ to make TransformerBlockList Iterable --- bmtrain/block_layer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/bmtrain/block_layer.py b/bmtrain/block_layer.py index 561b0039..cd792a88 100644 --- a/bmtrain/block_layer.py +++ b/bmtrain/block_layer.py @@ -831,7 +831,8 @@ def __init__(self, modules: Iterable[CheckpointBlock], sqrt=False) -> None: def __len__(self) -> int: return len(self._modules) - + def __iter__(self) -> Iterator[CheckpointBlock]: + return iter(self._modules.values()) def __getitem__(self, index: Union[int, str]) -> CheckpointBlock: return self._modules[str(index)]