From 870fefcce2a1fd33d46b40f8f762376d7aa6c740 Mon Sep 17 00:00:00 2001
From: Qianfeng Zhang
Date: Sat, 21 Dec 2024 09:04:32 +0000
Subject: [PATCH] Fix the scripts to pass flake8 checking

---
 tests/test_mem_eff_attention.py | 4 ++--
 xformers/ops/fmha/ck.py         | 4 +---
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/tests/test_mem_eff_attention.py b/tests/test_mem_eff_attention.py
index 14027a164..7a4f36d01 100644
--- a/tests/test_mem_eff_attention.py
+++ b/tests/test_mem_eff_attention.py
@@ -672,8 +672,8 @@ def test_backward(
     if op_bw == fmha.ck.BwOp:
         op_fw = fmha.ck.FwOp
         if dtype == torch.bfloat16:
-            ## bfloat16 testing can be enabled by export ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1 when
-            ## building xformers and get accurate results
+            # bfloat16 testing can be enabled by export ENABLE_HIP_FMHA_RTN_BF16_CONVERT=1 when
+            # building xformers and get accurate results
             pytest.skip(
                 "CK Fmha backward for bfloat16 currently is not very accurate for some cases!"
             )
diff --git a/xformers/ops/fmha/ck.py b/xformers/ops/fmha/ck.py
index 06787b80f..5d1137711 100644
--- a/xformers/ops/fmha/ck.py
+++ b/xformers/ops/fmha/ck.py
@@ -6,7 +6,6 @@
 
 from dataclasses import replace
 from enum import Enum
-from functools import partial
 from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple, Union
 
 import torch
@@ -38,7 +37,6 @@
     Context,
     Gradients,
     Inputs,
-    _attn_bias_apply,
     check_lastdim_alignment_stride1,
 )
 
@@ -218,7 +216,7 @@ def apply(
         assert inp.query.ndim == 5, f"query has shape {inp.query.shape}"
 
         ctx: Optional[Context] = None
-        ## consider for expanded 5-D inputted
+        # consider for expanded 5-D inputted
         if inp.key.stride()[3] == 0:
             assert (
                 inp.value.stride()[3] == 0
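
For reference, a minimal standalone sketch of the two kinds of flake8 findings this patch resolves. The commit message does not name the codes, but the changes match pycodestyle's E266 (too many leading "#" for block comment) and pyflakes' F401 (imported but unused); the hypothetical file below, written for illustration only, triggers both before the fix:

    # flake8_demo.py -- hypothetical file, not part of the commit.
    from functools import partial  # F401: "functools.partial" imported but unused


    def double(x: float) -> float:
        ## E266: a "##" block comment has too many leading "#" characters;
        # flake8 expects a single "#" followed by one space, as on this line.
        return x * 2.0


    print(double(21.0))  # prints 42.0

Running flake8 over such a file reports both codes; deleting the unused import and the doubled "#" brings the run back clean, mirroring the comment-style and import changes in the patch above.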