diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 14027a164..7a4f36d01 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -672,8 +672,8 @@ def test_backward(
     if op_bw == fmha.ck.BwOp:
         op_fw = fmha.ck.FwOp
         if dtype == torch.bfloat16:
-            ## bfloat16 testing can be enabled by export ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1 when
-            ## building xformers and get accurate results
+            # bfloat16 testing can be enabled by exporting ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1
+            # when building xformers to get accurate results
             pytest.skip(
                 "CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
             )
diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py
index 06787b80f..5d1137711 100644
--- a/xformers/ops/fmha/ck.py
+++ b/xformers/ops/fmha/ck.py
@@ -6,7 +6,6 @@

 from dataclasses import replace
 from enum import Enum
-from functools import partial
 from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple, Union

 import torch
@@ -38,7 +37,6 @@
     Context,
     Gradients,
     Inputs,
-    _attn_bias_apply,
     check_lastdim_alignment_stride1,
 )

@@ -218,7 +216,7 @@ def apply(
         assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"

         ctx: Optional[Context] = None
-        ## consider for expanded 5-D inputted
+        # handle the case of expanded 5-D inputs
         if inp.key.stride()[3] == 0:
             assert (
                 inp.value.stride()[3] == 0