From a8746f3bb3bab42aa6064b5e4ee9ba80ce0a2bcf Mon Sep 17 00:00:00 2001
From: dan_the_3rd <43445237+danthe3rd@users.noreply.github.com>
Date: Fri, 27 Dec 2024 12:20:08 +0000
Subject: [PATCH] Fix .to() method for all attention biases
 (fairinternal/xformers#1278)

__original_commit__ = fairinternal/xformers@0f7311b132dbc18482e233d43c526b101bc08bf6
---
 tests/test_mem_eff_attention.py |   2 +
 xformers/ops/fmha/attn_bias.py  | 115 ++++++++++++++++++++++++++++++++
 2 files changed, 117 insertions(+)

diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 536a9c7cfb..6c59f4e4ac 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -462,6 +462,8 @@ def test_forward(opFW_device_dtype_biasT_B_Mq_Mkv_H_K_Kv, packed, fmt, **kwargs)
         fmt="BMHK" if packed else fmt,
         **kwargs,
     )
+    if attn_bias is not None:
+        assert type(attn_bias.to(query.device)) is type(attn_bias)
 
     if packed:
         c = torch.stack([query, key, value], 2)
diff --git a/xformers/ops/fmha/attn_bias.py b/xformers/ops/fmha/attn_bias.py
index 1466bb240f..b72c553d6b 100644
--- a/xformers/ops/fmha/attn_bias.py
+++ b/xformers/ops/fmha/attn_bias.py
@@ -169,6 +169,9 @@ class LocalAttentionFromBottomRightMask(AttentionBias):
     window_left: int
     window_right: int
 
+    def to(self, device) -> "LocalAttentionFromBottomRightMask":
+        return self
+
     def __post_init__(self) -> None:
         if self.window_left < 0:
             raise ValueError(
@@ -227,6 +230,9 @@ class LowerTriangularFromBottomRightMask(AttentionBias):
     """
 
     def to(self, device: torch.device) -> "LowerTriangularFromBottomRightMask":
+        assert (
+            type(self) is LowerTriangularFromBottomRightMask
+        ), "Please implement in subclass"
         return self
 
     def materialize(
@@ -273,6 +279,14 @@ class LowerTriangularFromBottomRightLocalAttentionMask(
 
     _window_size: int
 
+    def to(
+        self, device: torch.device
+    ) -> "LowerTriangularFromBottomRightLocalAttentionMask":
+        assert (
+            type(self) is LowerTriangularFromBottomRightLocalAttentionMask
+        ), "Please implement in subclass"
+        return self
+
     def __post_init__(self) -> None:
         if self._window_size <= 0:
             raise ValueError(
@@ -314,6 +328,7 @@ class _SeqLenInfo:
     seqstart_py: List[int]
 
     def to(self, device: torch.device) -> "_SeqLenInfo":
+        assert type(self) is _SeqLenInfo, "Please implement in subclass"
         if self.seqstart.device == device:
             return self
         return _SeqLenInfo(
@@ -437,6 +452,7 @@ def __post_init__(self) -> None:
         assert len(self.seqstart_py) == len(self.seqlen_py) + 1
 
     def to(self, device: torch.device) -> "_PaddedSeqLenInfo":
+        assert type(self) is _PaddedSeqLenInfo, "Please implement in subclass"
         if self.seqlen.device == device:
             return self
         return _PaddedSeqLenInfo(
@@ -552,6 +568,7 @@ class _GappySeqInfo(_SeqLenInfo):
     # seqstart: torch.Tensor
 
     def to(self, device: torch.device) -> "_GappySeqInfo":
+        assert type(self) is _GappySeqInfo, "Please implement in subclass"
         if self.seqlen.device == device:
             return self
         return _GappySeqInfo(
@@ -654,6 +671,7 @@ class BlockDiagonalMask(AttentionBias):
     _batch_sizes: Optional[Sequence[int]] = None
 
     def to(self, device) -> "BlockDiagonalMask":
+        assert type(self) is BlockDiagonalMask, "Please implement in subclass"
         return BlockDiagonalMask(
             q_seqinfo=self.q_seqinfo.to(device),
             k_seqinfo=self.k_seqinfo.to(device),
             _batch_sizes=self._batch_sizes,
         )
@@ -858,6 +876,14 @@ class BlockDiagonalCausalMask(BlockDiagonalMask):
         is from the initial query in block i.
     """
 
+    def to(self, device) -> "BlockDiagonalCausalMask":
+        assert type(self) is BlockDiagonalCausalMask, "Please implement in subclass"
+        return BlockDiagonalCausalMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+            _batch_sizes=self._batch_sizes,
+        )
+
     def _create_block_mask(
         self,
         shape: Tuple[int, ...],
@@ -885,6 +911,16 @@ class BlockDiagonalCausalFromBottomRightMask(BlockDiagonalMask):
         final query in block i.
     """
 
+    def to(self, device) -> "BlockDiagonalCausalFromBottomRightMask":
+        assert (
+            type(self) is BlockDiagonalCausalFromBottomRightMask
+        ), "Please implement in subclass"
+        return BlockDiagonalCausalFromBottomRightMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+            _batch_sizes=self._batch_sizes,
+        )
+
     def __post_init__(self) -> None:
         for i, ((q_start, q_end), (k_start, k_end)) in enumerate(
             zip(
@@ -933,6 +969,7 @@ class BlockDiagonalPaddedKeysMask(AttentionBias):
     k_seqinfo: _PaddedSeqLenInfo
 
     def to(self, device) -> "BlockDiagonalPaddedKeysMask":
+        assert type(self) is BlockDiagonalPaddedKeysMask, "Please implement in subclass"
         return BlockDiagonalPaddedKeysMask(
             q_seqinfo=self.q_seqinfo.to(device),
             k_seqinfo=self.k_seqinfo.to(device),
@@ -1044,6 +1081,15 @@ class BlockDiagonalCausalWithOffsetPaddedKeysMask(BlockDiagonalPaddedKeysMask):
 
     causal_diagonal: Any = None  # unused. Exists for BC only.
 
+    def to(self, device) -> "BlockDiagonalCausalWithOffsetPaddedKeysMask":
+        assert (
+            type(self) is BlockDiagonalCausalWithOffsetPaddedKeysMask
+        ), "Please implement in subclass"
+        return BlockDiagonalCausalWithOffsetPaddedKeysMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+        )
+
     def _create_block_mask(
         self,
         shape: Tuple[int, ...],
@@ -1103,6 +1149,16 @@ class BlockDiagonalCausalLocalAttentionPaddedKeysMask(BlockDiagonalPaddedKeysMas
 
     _window_size: int
 
+    def to(self, device) -> "BlockDiagonalCausalLocalAttentionPaddedKeysMask":
+        assert (
+            type(self) is BlockDiagonalCausalLocalAttentionPaddedKeysMask
+        ), "Please implement in subclass"
+        return BlockDiagonalCausalLocalAttentionPaddedKeysMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+            _window_size=self._window_size,
+        )
+
     def _create_block_mask(
         self,
         shape: Tuple[int, ...],
@@ -1153,6 +1209,9 @@ class PagedBlockDiagonalPaddedKeysMask(AttentionBias):
     ] = BlockDiagonalPaddedKeysMask
 
     def to(self, device: torch.device) -> "PagedBlockDiagonalPaddedKeysMask":
+        assert (
+            type(self) is PagedBlockDiagonalPaddedKeysMask
+        ), "Please implement in subclass"
         return PagedBlockDiagonalPaddedKeysMask(
             q_seqinfo=self.q_seqinfo.to(device),
             k_seqinfo=self.k_seqinfo.to(device),
@@ -1250,6 +1309,19 @@ class PagedBlockDiagonalCausalWithOffsetPaddedKeysMask(
 
     _UNPAGED_TYPE = BlockDiagonalCausalWithOffsetPaddedKeysMask
 
+    def to(
+        self, device: torch.device
+    ) -> "PagedBlockDiagonalCausalWithOffsetPaddedKeysMask":
+        assert (
+            type(self) is PagedBlockDiagonalCausalWithOffsetPaddedKeysMask
+        ), "Please implement in subclass"
+        return PagedBlockDiagonalCausalWithOffsetPaddedKeysMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+            block_tables=self.block_tables.to(device),
+            page_size=self.page_size,
+        )
+
 
 @dataclass
 class BlockDiagonalGappyKeysMask(AttentionBias):
@@ -1264,6 +1336,7 @@ class BlockDiagonalGappyKeysMask(AttentionBias):
     k_seqinfo: _GappySeqInfo
 
     def to(self, device: torch.device) -> "BlockDiagonalGappyKeysMask":
+        assert type(self) is BlockDiagonalGappyKeysMask, "Please implement in subclass"
         return BlockDiagonalGappyKeysMask(
             q_seqinfo=self.q_seqinfo.to(device),
             k_seqinfo=self.k_seqinfo.to(device),
@@ -1359,6 +1432,15 @@ class BlockDiagonalCausalWithOffsetGappyKeysMask(BlockDiagonalGappyKeysMask):
         than Q is to the final query in block i.
     """
 
+    def to(self, device: torch.device) -> "BlockDiagonalCausalWithOffsetGappyKeysMask":
+        assert (
+            type(self) is BlockDiagonalCausalWithOffsetGappyKeysMask
+        ), "Please implement in subclass"
+        return BlockDiagonalCausalWithOffsetGappyKeysMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+        )
+
     def materialize(
         self,
         shape: Tuple[int, ...],
@@ -1407,6 +1489,17 @@ class PagedBlockDiagonalGappyKeysMask(AttentionBias):
         Type[BlockDiagonalGappyKeysMask]
     ] = BlockDiagonalGappyKeysMask
 
+    def to(self, device: torch.device) -> "PagedBlockDiagonalGappyKeysMask":
+        assert (
+            type(self) is PagedBlockDiagonalGappyKeysMask
+        ), "Please implement in subclass"
+        return PagedBlockDiagonalGappyKeysMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+            block_tables=self.block_tables.to(device),
+            page_size=self.page_size,
+        )
+
     def materialize(
         self,
         shape: Tuple[int, ...],
@@ -1507,6 +1600,17 @@ class BlockDiagonalCausalLocalAttentionMask(BlockDiagonalCausalMask):
 
     _window_size: int = 0  # forced due to inheritance and default arguments
 
+    def to(self, device) -> "BlockDiagonalCausalLocalAttentionMask":
+        assert (
+            type(self) is BlockDiagonalCausalLocalAttentionMask
+        ), "Please implement in subclass"
+        return BlockDiagonalCausalLocalAttentionMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+            _batch_sizes=self._batch_sizes,
+            _window_size=self._window_size,
+        )
+
     def __post_init__(self):
         if self._window_size <= 0:
             raise ValueError(
@@ -1561,6 +1665,17 @@ class BlockDiagonalCausalLocalAttentionFromBottomRightMask(
 
     _window_size: int = 0  # forced due to inheritance and default arguments
 
+    def to(self, device) -> "BlockDiagonalCausalLocalAttentionFromBottomRightMask":
+        assert (
+            type(self) is BlockDiagonalCausalLocalAttentionFromBottomRightMask
+        ), "Please implement in subclass"
+        return BlockDiagonalCausalLocalAttentionFromBottomRightMask(
+            q_seqinfo=self.q_seqinfo.to(device),
+            k_seqinfo=self.k_seqinfo.to(device),
+            _batch_sizes=self._batch_sizes,
+            _window_size=self._window_size,
+        )
+
     def __post_init__(self):
         super().__post_init__()
        if self._window_size <= 0:
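
For reference, a minimal usage sketch (not part of the patch) of the behaviour the new test asserts: after this fix, calling .to(device) on an attention bias returns an instance of the same subclass, with its sequence-length tensors moved to the target device. The sequence lengths [3, 5, 2] and the CUDA device below are illustrative assumptions; BlockDiagonalCausalMask.from_seqlens and the q_seqinfo.seqstart tensor are existing xformers APIs touched by this diff.

    import torch
    from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask

    # Build a block-diagonal causal bias on CPU from per-sequence lengths.
    bias = BlockDiagonalCausalMask.from_seqlens([3, 5, 2])

    if torch.cuda.is_available():  # device choice is an assumption for this sketch
        moved = bias.to(torch.device("cuda"))
        # The exact subclass is preserved, mirroring the new test assertion.
        assert type(moved) is BlockDiagonalCausalMask
        # The underlying seqstart tensor now lives on the target device.
        assert moved.q_seqinfo.seqstart.device.type == "cuda"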