Skip to content

Commit

Permalink
Merge pull request jax-ml#25136 from ROCm:ci_dockerfile_arg_changes-u…
Browse files Browse the repository at this point in the history
…pstream

PiperOrigin-RevId: 701959495
  • Loading branch information
Google-ML-Automation committed Dec 2, 2024
2 parents bd66f52 + 8df2766 commit 7b32d88
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 2 deletions.
3 changes: 2 additions & 1 deletion build/rocm/Dockerfile.ms
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
################################################################################
FROM ubuntu:20.04 AS rocm_base
ARG BASE_DOCKER=ubuntu:22.04
FROM $BASE_DOCKER AS rocm_base
################################################################################

RUN --mount=type=cache,target=/var/cache/apt \
Expand Down
9 changes: 9 additions & 0 deletions build/rocm/ci_build
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def _fetch_jax_metadata(xla_path):

def dist_docker(
rocm_version,
base_docker,
python_versions,
xla_path,
rocm_build_job="",
Expand All @@ -168,6 +169,7 @@ def dist_docker(
"--build-arg=ROCM_VERSION=%s" % rocm_version,
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
"--build-arg=BASE_DOCKER=%s" % base_docker,
"--build-arg=PYTHON_VERSION=%s" % python_version,
"--build-arg=JAX_VERSION=%(jax_version)s" % md,
"--build-arg=JAX_COMMIT=%(jax_commit)s" % md,
Expand Down Expand Up @@ -231,6 +233,12 @@ def test(image_name):

def parse_args():
p = argparse.ArgumentParser()
p.add_argument(
"--base-docker",
default="",
help="Argument to override base docker in dockerfile",
)

p.add_argument(
"--python-versions",
type=lambda x: x.split(","),
Expand Down Expand Up @@ -308,6 +316,7 @@ def main():
)
dist_docker(
args.rocm_version,
args.base_docker,
args.python_versions,
args.xla_source_dir,
rocm_build_job=args.rocm_build_job,
Expand Down
7 changes: 6 additions & 1 deletion build/rocm/ci_build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ PYTHON_VERSION="3.10"
ROCM_VERSION="6.1.3"
ROCM_BUILD_JOB=""
ROCM_BUILD_NUM=""
BASE_DOCKER="ubuntu:20.04"
BASE_DOCKER="ubuntu:22.04"
CUSTOM_INSTALL=""
JAX_USE_CLANG=""
POSITIONAL_ARGS=()
Expand Down Expand Up @@ -90,6 +90,10 @@ while [[ $# -gt 0 ]]; do
ROCM_BUILD_NUM="$2"
shift 2
;;
--base_docker)
BASE_DOCKER="$2"
shift 2
;;
--use_clang)
JAX_USE_CLANG="$2"
shift 2
Expand Down Expand Up @@ -154,6 +158,7 @@ fi
# which is the ROCm image that is shipped for users to use (i.e. distributable).
./build/rocm/ci_build \
--rocm-version $ROCM_VERSION \
--base-docker $BASE_DOCKER \
--python-versions $PYTHON_VERSION \
--xla-source-dir=$XLA_CLONE_DIR \
--rocm-build-job=$ROCM_BUILD_JOB \
Expand Down

0 comments on commit 7b32d88

Please sign in to comment.