diff --git a/jax/experimental/mesh_utils.py b/jax/experimental/mesh_utils.py index 0807535ebf0d..faf48a841f07 100644 --- a/jax/experimental/mesh_utils.py +++ b/jax/experimental/mesh_utils.py @@ -137,8 +137,6 @@ def _create_device_mesh_for_nd_torus( physical topology. mesh_shape: shape of the logical mesh (size of the various logical parallelism axes), with axes ordered by increasing network intensity. - prefer_symmetric: whether to prefer to assign a logical axis to multiple - physical axes of the same size rather than axes of different sizes. Returns: An np.ndarray of devices in the shape of the logical mesh (mesh_shape), with