Fixing batch_dim_name attribute (#20674)
* fixing wrong trainer assumption that batch dim is always the first one in the mesh

* need functools partial

* lint

* fix test failure when distribution=None

* lint2

* fix for test failure

* added data sharding for 3D+ meshes

* lint3

* added @property for batch_dim_name + refactoring

* fix typo
martin-gorner authored Jan 7, 2025
1 parent ab3c8f5 commit fbf0af7
Showing 5 changed files with 32 additions and 18 deletions.
12 changes: 8 additions & 4 deletions keras/src/backend/jax/distribution_lib.py
@@ -100,7 +100,7 @@ def distribute_tensor(tensor, layout):
     return global_value


-def distribute_data_input(per_process_batch, layout):
+def distribute_data_input(per_process_batch, layout, batch_dim_name):
     """Distribute the input data with the corresponding layout.

     Note that the inputs here is a local worker batch. Within the local worker,
@@ -117,9 +117,13 @@ def distribute_data_input(per_process_batch, layout):
     if not isinstance(layout, jax.sharding.Sharding):
         layout = _to_jax_layout(layout)

-    mesh_shape = list(layout.mesh.shape.values())
-    num_model_replicas_total = mesh_shape[0]  # batch dimension of the mesh
-    mesh_model_dim_size = mesh_shape[1] if len(mesh_shape) > 1 else 1
+    num_model_replicas_total = layout.mesh.shape[batch_dim_name]
+
+    mesh_model_dim_size = 1
+    for name, dim_size in layout.mesh.shape.items():
+        if not name == batch_dim_name:
+            mesh_model_dim_size *= dim_size

     num_model_replicas_per_process = num_model_replicas_total / num_processes()
     per_process_batch_size = per_process_batch.shape[0]
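The key change above is that the batch axis is now looked up by name rather than assumed to be the first mesh dimension, and the model dimension size becomes the product of every remaining axis. A minimal sketch of that arithmetic, using a plain dict with hypothetical axis names to stand in for `layout.mesh.shape`:

# Sketch of the new replica math; a plain dict stands in for
# `layout.mesh.shape` (an ordered mapping of axis name -> axis size).
mesh_shape = {"model": 2, "batch": 2, "seq": 2}  # hypothetical 3D mesh
batch_dim_name = "batch"

# Number of model replicas = size of the batch axis, wherever it sits.
num_model_replicas_total = mesh_shape[batch_dim_name]

# Devices per model replica = product of all non-batch axes, which is
# what the old `mesh_shape[1]` shortcut got wrong for 3D+ meshes.
mesh_model_dim_size = 1
for name, dim_size in mesh_shape.items():
    if name != batch_dim_name:
        mesh_model_dim_size *= dim_size

print(num_model_replicas_total, mesh_model_dim_size)  # 2 4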
4 changes: 3 additions & 1 deletion keras/src/backend/jax/distribution_lib_test.py
@@ -337,7 +337,9 @@ def test_distribute_data_input(self):
             mesh, jax.sharding.PartitionSpec("batch", None)
         )

-        result = backend_dlib.distribute_data_input(per_process_batch, layout)
+        result = backend_dlib.distribute_data_input(
+            per_process_batch, layout, "batch"
+        )

         # Check the shape of the global batch array
         self.assertEqual(
8 changes: 6 additions & 2 deletions keras/src/backend/jax/trainer.py
@@ -1,5 +1,6 @@
 import collections
 import itertools
+from functools import partial

 import jax
 import numpy as np
@@ -988,15 +989,18 @@ def _get_jax_state(

 def _distribute_data(data, layouts=None):
     distribution = distribution_lib.distribution()
+
     if distribution is not None:
         if layouts is None:
             layouts = tree.map_structure(
                 lambda d: distribution.get_data_layout(d.shape),
                 data,
             )
-        return tree.map_structure(
-            jax_distribution_lib.distribute_data_input, data, layouts
+        jax_dist_data_input = partial(
+            jax_distribution_lib.distribute_data_input,
+            batch_dim_name=distribution.batch_dim_name,
         )
+        return tree.map_structure(jax_dist_data_input, data, layouts)

     return tree.map_structure(jax.device_put, data)
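The trainer binds `batch_dim_name` up front with `functools.partial` because `tree.map_structure` only passes the structure leaves (data and layout) positionally. A toy sketch of the same pattern, with a simplified stand-in for the real `tree.map_structure` and a hypothetical shard function:

from functools import partial

# Hypothetical stand-in for the real distribute_data_input.
def distribute_data_input(per_process_batch, layout, batch_dim_name):
    return f"shard {per_process_batch} with {layout} along '{batch_dim_name}'"

def map_structure(fn, *structures):
    # Simplified tree.map_structure for flat lists.
    return [fn(*leaves) for leaves in zip(*structures)]

data = ["x_batch", "y_batch"]
layouts = ["layout_x", "layout_y"]

# Bind the extra keyword once so the mapped function still takes (data, layout).
dist_fn = partial(distribute_data_input, batch_dim_name="batch")
print(map_structure(dist_fn, data, layouts))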
20 changes: 12 additions & 8 deletions keras/src/distribution/distribution_lib.py
@@ -287,8 +287,9 @@ class Distribution:
         device_mesh: A `DeviceMesh` instance.
     """

-    def __init__(self, device_mesh):
+    def __init__(self, device_mesh, batch_dim_name=None):
         self._device_mesh = device_mesh
+        self._batch_dim_name = batch_dim_name

     def get_data_layout(self, data_shape):
         """Retrieve the `TensorLayout` for the input data.
@@ -341,6 +342,10 @@ def scope(self):
     def device_mesh(self):
         return self._device_mesh

+    @property
+    def batch_dim_name(self):
+        return self._batch_dim_name
+
     def distribute_dataset(self, dataset):
         """Create a distributed dataset instance from the original user dataset.
@@ -395,7 +400,6 @@ def __init__(self, device_mesh=None, devices=None, auto_shard_dataset=True):
         else:
             self._initialize_mesh_from_list_devices()

-        self._batch_dim_name = self.device_mesh.axis_names[0]
         # Those following attributes might get convert to public methods.
         self._num_process = distribution_lib.num_processes()
         self._process_id = distribution_lib.process_id()
@@ -408,7 +412,7 @@ def _initialize_with_device_mesh(self, device_mesh):
                 "Expect `mesh` to be an instance of `DeviceMesh`. "
                 f"Received: mesh={device_mesh} (of type {type(device_mesh)})"
             )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, device_mesh.axis_names[0])
         if self.device_mesh.devices.ndim != 1:
             warnings.warn(
                 "Expect the input mesh to be 1D, but received "
@@ -424,7 +428,7 @@ def _initialize_mesh_from_devices(self, devices):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)

     def _initialize_mesh_from_list_devices(self):
         devices = np.array(list_devices())
@@ -433,11 +437,11 @@ def _initialize_mesh_from_list_devices(self):
             axis_names=[DEFAULT_BATCH_DIM_NAME],
             devices=devices,
         )
-        super().__init__(device_mesh)
+        super().__init__(device_mesh, DEFAULT_BATCH_DIM_NAME)

     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)

     def get_variable_layout(self, variable):
@@ -590,7 +594,7 @@ def __init__(self, *, layout_map=None, batch_dim_name=None, **kwargs):

     def get_data_layout(self, data_shape):
         data_shard_spec = [None] * len(data_shape)
-        data_shard_spec[0] = self._batch_dim_name  # Shard on the first dim
+        data_shard_spec[0] = self.batch_dim_name  # Shard on the first dim
         return TensorLayout(data_shard_spec, self.device_mesh)

     def get_variable_layout(self, variable):
@@ -631,7 +635,7 @@ def distribute_dataset(self, dataset):
         # Note that this might be smaller than one if model replicas are sharded
         # across multiple processes.
         mesh_batch_dim_index = self.device_mesh.axis_names.index(
-            self._batch_dim_name
+            self.batch_dim_name
         )
         num_model_replicas = self.device_mesh.shape[mesh_batch_dim_index]
         if num_model_replicas == 1:
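The refactor above moves ownership of the batch axis name into the base `Distribution` class, and subclasses now pass it to `super().__init__` instead of deriving it later from the mesh. A stripped-down mirror of that pattern (not the real Keras classes, just the shape of the change):

# Toy mirror of the refactor: the base class owns batch_dim_name,
# subclasses hand it the axis name at construction time.
class Distribution:
    def __init__(self, device_mesh, batch_dim_name=None):
        self._device_mesh = device_mesh
        self._batch_dim_name = batch_dim_name

    @property
    def batch_dim_name(self):
        return self._batch_dim_name


class DataParallel(Distribution):
    def __init__(self, axis_names):
        # For pure data parallelism the first mesh axis is the batch axis.
        super().__init__(device_mesh=None, batch_dim_name=axis_names[0])

    def get_data_layout(self, data_shape):
        spec = [None] * len(data_shape)
        spec[0] = self.batch_dim_name  # shard only the batch dimension
        return spec


print(DataParallel(["data"]).get_data_layout((8, 32)))  # ['data', None]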
6 changes: 3 additions & 3 deletions keras/src/distribution/distribution_lib_test.py
@@ -186,7 +186,7 @@ def test_create_with_device_mesh(self):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["data"])
-        self.assertEqual(distribution._batch_dim_name, "data")
+        self.assertEqual(distribution.batch_dim_name, "data")

         self.assertFalse(distribution._is_multi_process)
         self.assertEqual(distribution._process_id, 0)
@@ -197,7 +197,7 @@ def test_create_with_devices(self):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")

     @mock.patch.object(
         distribution_lib,
@@ -211,7 +211,7 @@ def test_create_with_list_devices(self, mock_list_devices):
         device_mesh = distribution.device_mesh
         self.assertEqual(len(device_mesh.devices), 8)
         self.assertEqual(device_mesh.axis_names, ["batch"])
-        self.assertEqual(distribution._batch_dim_name, "batch")
+        self.assertEqual(distribution.batch_dim_name, "batch")

     def test_get_data_layout(self):
         distribution = distribution_lib.DataParallel(
