diff --git a/tests/trainer/test_fsdp_checkpoint.py b/tests/trainer/test_fsdp_checkpoint.py index 9f785a94ff..332ed5b7b7 100644 --- a/tests/trainer/test_fsdp_checkpoint.py +++ b/tests/trainer/test_fsdp_checkpoint.py @@ -514,6 +514,7 @@ def test_fsdp_mixed_with_sync( '0.23.0', '0.24.0', '0.25.0', + '0.26.0', ], ) @pytest.mark.filterwarnings(r'ignore:.*metrics are not saved with sharded state dict.*:UserWarning') @@ -534,8 +535,8 @@ def test_fsdp_load_old_checkpoint( pytest.skip('TODO: This checkpoint is missing') if (composer_version in ['0.22.0', '0.23.0'] and version.parse(torch.__version__) < version.parse('2.3.0')) or ( - composer_version == '0.24.0' and version.parse(torch.__version__) < version.parse('2.4.0') - ) or (composer_version == '0.25.0' and version.parse(torch.__version__) < version.parse('2.5.0')): + composer_version in ['0.24.0', '0.25.0'] and version.parse(torch.__version__) < version.parse('2.4.0') + ) or (composer_version in '0.26.0' and version.parse(torch.__version__) < version.parse('2.5.0')): pytest.skip('Current torch version is older than torch version that checkpoint was written with.') if composer_version in ['0.13.5', '0.14.0', '0.14.1', '0.15.1']: