Skip to content

Commit

Permalink
[Pallas TPU] Add helpers file with copy_ref function
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 716030813
  • Loading branch information
sharadmv authored and Google-ML-Automation committed Jan 16, 2025
1 parent 4a9cc9f commit 0ac6315
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 0 deletions.
1 change: 1 addition & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,7 @@ pytype_strict_library(
":tpu_custom_call",
"//jax/_src/pallas",
"//jax/_src/pallas/mosaic:core",
"//jax/_src/pallas/mosaic:helpers",
"//jax/_src/pallas/mosaic:lowering",
"//jax/_src/pallas/mosaic:pallas_call_registration", # build_cleaner: keep
"//jax/_src/pallas/mosaic:pipeline",
Expand Down
11 changes: 11 additions & 0 deletions jax/_src/pallas/mosaic/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -137,3 +137,14 @@ py_library(
"//jax:typing",
] + py_deps("numpy"),
)

py_library(
name = "helpers",
srcs = ["helpers.py"],
deps = [
":core",
":primitives",
"//jax",
"//jax/_src/pallas",
],
)
57 changes: 57 additions & 0 deletions jax/_src/pallas/mosaic/helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for Pallas TPU kernels."""

import functools
import jax
from jax._src.pallas import primitives as pl_primitives
from jax._src.pallas.mosaic import core as tpu_core
from jax._src.pallas.mosaic import primitives as plm_primitives


def sync_copy(src_ref, dst_ref):
"""Copies a PyTree of Refs to another PyTree of Refs.
Args:
src_ref: A Pytree of source Refs/TransformedRefs.
dst_ref: A Pytree of destination Refs/TransformedRefs.
"""
if not jax.tree.leaves(src_ref):
# No buffers to copy so skip the function.
return

@functools.partial(
pl_primitives.run_scoped, sem=tpu_core.SemaphoreType.DMA(())
)
def _(sem):
def _copy_start_or_wait(action, src_ref, dst_ref):
descriptor = plm_primitives.make_async_copy(src_ref, dst_ref, sem)
if action == "start":
descriptor.start()
elif action == "wait":
descriptor.wait()
else:
raise ValueError(f"Unknown action: {action}")

jax.tree.map(
functools.partial(_copy_start_or_wait, "start"),
src_ref,
dst_ref,
)
jax.tree.map(
functools.partial(_copy_start_or_wait, "wait"),
src_ref,
dst_ref,
)
1 change: 1 addition & 0 deletions jax/experimental/pallas/tpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax._src.pallas.mosaic.core import TPUCompilerParams as TPUCompilerParams
from jax._src.pallas.mosaic.core import runtime_assert_enabled as runtime_assert_enabled
from jax._src.pallas.mosaic.core import _ENABLE_RUNTIME_ASSERT as enable_runtime_assert # noqa: F401
from jax._src.pallas.mosaic.helpers import sync_copy as sync_copy
from jax._src.pallas.mosaic.lowering import LoweringException as LoweringException
from jax._src.pallas.mosaic.pipeline import ARBITRARY as ARBITRARY
from jax._src.pallas.mosaic.pipeline import BufferedRef as BufferedRef
Expand Down

0 comments on commit 0ac6315

Please sign in to comment.