From 8df2766466add14025d13b5603898ea0943788b3 Mon Sep 17 00:00:00 2001 From: Ruturaj4 Date: Tue, 26 Nov 2024 16:22:24 -0600 Subject: [PATCH] Add argument to override base docker in dockerfile --- build/rocm/Dockerfile.ms | 3 ++- build/rocm/ci_build | 9 +++++++++ build/rocm/ci_build.sh | 7 ++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index e20291cefd63..575dce87664e 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -1,5 +1,6 @@ ################################################################################ -FROM ubuntu:20.04 AS rocm_base +ARG BASE_DOCKER=ubuntu:22.04 +FROM $BASE_DOCKER AS rocm_base ################################################################################ RUN --mount=type=cache,target=/var/cache/apt \ diff --git a/build/rocm/ci_build b/build/rocm/ci_build index 1ec5c6e7f36f..f3b8ae401649 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -143,6 +143,7 @@ def _fetch_jax_metadata(xla_path): def dist_docker( rocm_version, + base_docker, python_versions, xla_path, rocm_build_job="", @@ -168,6 +169,7 @@ def dist_docker( "--build-arg=ROCM_VERSION=%s" % rocm_version, "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, + "--build-arg=BASE_DOCKER=%s" % base_docker, "--build-arg=PYTHON_VERSION=%s" % python_version, "--build-arg=JAX_VERSION=%(jax_version)s" % md, "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, @@ -231,6 +233,12 @@ def test(image_name): def parse_args(): p = argparse.ArgumentParser() + p.add_argument( + "--base-docker", + default="", + help="Argument to override base docker in dockerfile", + ) + p.add_argument( "--python-versions", type=lambda x: x.split(","), @@ -308,6 +316,7 @@ def main(): ) dist_docker( args.rocm_version, + args.base_docker, args.python_versions, args.xla_source_dir, rocm_build_job=args.rocm_build_job, diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 0a50b5845d69..386f70ee1a96 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -48,7 +48,7 @@ PYTHON_VERSION="3.10" ROCM_VERSION="6.1.3" ROCM_BUILD_JOB="" ROCM_BUILD_NUM="" -BASE_DOCKER="ubuntu:20.04" +BASE_DOCKER="ubuntu:22.04" CUSTOM_INSTALL="" JAX_USE_CLANG="" POSITIONAL_ARGS=() @@ -90,6 +90,10 @@ while [[ $# -gt 0 ]]; do ROCM_BUILD_NUM="$2" shift 2 ;; + --base_docker) + BASE_DOCKER="$2" + shift 2 + ;; --use_clang) JAX_USE_CLANG="$2" shift 2 @@ -154,6 +158,7 @@ fi # which is the ROCm image that is shipped for users to use (i.e. distributable). ./build/rocm/ci_build \ --rocm-version $ROCM_VERSION \ + --base-docker $BASE_DOCKER \ --python-versions $PYTHON_VERSION \ --xla-source-dir=$XLA_CLONE_DIR \ --rocm-build-job=$ROCM_BUILD_JOB \