diff --git a/tests/BUILD b/tests/BUILD index 90a8913a1825..ba725842b36c 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -416,11 +416,6 @@ jax_multiplatform_test( jax_multiplatform_test( name = "custom_root_test", srcs = ["custom_root_test.py"], - shard_count = { - "cpu": 10, - "gpu": 10, - "tpu": 10, - }, ) jax_multiplatform_test(