diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py index 0b1dc082765e..add297b6a351 100644 --- a/jax/_src/sharding_impls.py +++ b/jax/_src/sharding_impls.py @@ -691,13 +691,9 @@ def check_compatible_aval(self, aval_shape: Shape) -> None: def _remake( cls, devices: tuple[xc.Device, ...], ids: np.ndarray, *, memory_kind: str | None = None) -> PositionalSharding: - self = cls.__new__(cls) - self._devices = devices - self._ids = ids - self._internal_device_list = xc.DeviceList(self._devices) - self._memory_kind = xc.check_and_canonicalize_memory_kind( - memory_kind, self._internal_device_list) - return self + sharding = cls(devices, memory_kind=memory_kind) + sharding._ids = ids + return sharding # Hashable