Skip to content

Commit

Permalink
Add an experimental Cloud TPU presubmit job
Browse files Browse the repository at this point in the history
This adds an experimental non-blocking presubmit job that will run a subset of TPU tests, focusing on frequently failing tests. The goal is to achieve comprehensive coverage while keeping the runtime around 10 minutes.

PiperOrigin-RevId: 706064568
  • Loading branch information
nitins17 authored and Google-ML-Automation committed Dec 14, 2024
1 parent f4e5f14 commit d05ab5b
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 13 deletions.
93 changes: 93 additions & 0 deletions .github/workflows/cloud-tpu-ci-presubmit.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Cloud TPU CI (presubmit)
#
# This job currently runs as a non-blocking presubmit. It is experimental and is currently being
# tested to get to a stable state before we enable it as a blocking presubmit.
name: CI - Cloud TPU (presubmit)
on:
workflow_dispatch:
inputs:
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: choice
required: true
default: 'no'
options:
- 'yes'
- 'no'
pull_request:
branches:
- main

# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
contents: read

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
cancel-in-progress: true

jobs:
cloud-tpu-test:
if: github.event.repository.fork == false
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
tpu: [
{type: "v5e-8", cores: "8", runner: "linux-x86-ct5lp-224-8tpu"}
]
python-version: ["3.10"]

name: "TPU test (jaxlib=head, ${{ matrix.tpu.type }})"

env:
JAXCI_PYTHON: python${{ matrix.python-version }}
JAXCI_TPU_CORES: ${{ matrix.tpu.cores }}

runs-on: ${{ matrix.tpu.runner }}
container: "us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest"

timeout-minutes: 60

defaults:
run:
shell: bash -ex {0}

steps:
# https://opensource.google/documentation/reference/github/services#actions
# mandates using a specific commit for non-Google actions. We use
# https://github.com/sethvargo/ratchet to pin specific versions.
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
# Checkout XLA at head, if we're building jaxlib at head.
- name: Checkout XLA at head
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
repository: openxla/xla
path: xla
# We need to mark the GitHub workspace as safe as otherwise git commands will fail.
- name: Mark GitHub workspace as safe
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
- name: Install JAX test requirements
run: |
$JAXCI_PYTHON -m pip install -U -r build/test-requirements.txt
$JAXCI_PYTHON -m pip install -U -r build/collect-profile-requirements.txt
- name: Build jaxlib at head with latest XLA
run: |
# Build and install jaxlib at head
$JAXCI_PYTHON build/build.py build --wheels=jaxlib \
--python_version=${{ matrix.python-version }} \
--bazel_options=--config=rbe_linux_x86_64 \
--local_xla_path="$(pwd)/xla" \
--verbose
# Install libtpu
$JAXCI_PYTHON -m pip install --pre libtpu \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@main
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Install jaxlib wheel and run tests
run: ./ci/run_pytest_tpu.sh
25 changes: 12 additions & 13 deletions ci/run_pytest_tpu.sh
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -33,29 +33,28 @@ source ./ci/utilities/install_wheels_locally.sh
# Set up the build environment.
source "ci/utilities/setup_build_environment.sh"

export PY_COLORS=1
export JAX_SKIP_SLOW_TESTS=true

"$JAXCI_PYTHON" -c "import jax; print(jax.default_backend()); print(jax.devices()); print(len(jax.devices()))"

"$JAXCI_PYTHON" -c 'import sys; print("python version:", sys.version)'
"$JAXCI_PYTHON" -c 'import jax; print("jax version:", jax.__version__)'
"$JAXCI_PYTHON" -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
strings /usr/local/lib/"$JAXCI_PYTHON"/site-packages/libtpu/libtpu.so | grep 'Built on'
strings /usr/local/lib/"$JAXCI_PYTHON"/dist-packages/libtpu/libtpu.so | grep 'Built on'
"$JAXCI_PYTHON" -c 'import jax; print("libtpu version:",jax.lib.xla_bridge.get_backend().platform_version)'

echo "Running TPU tests..."
# Set up all common test environment variables
export PY_COLORS=1
export JAX_PLATFORMS=tpu,cpu
# Run single-accelerator tests in parallel
export JAX_ENABLE_TPU_XDIST=true
export JAX_SKIP_SLOW_TESTS=true
# End of common test environment variable setup

"$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
echo "Running TPU tests..."

# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true "$JAXCI_PYTHON" -m pytest -n="$JAXCI_TPU_CORES" --tb=short \
--deselect=tests/pallas/tpu_pallas_test.py::PallasCallPrintTest \
--maxfail=20 -m "not multiaccelerator" tests examples
--maxfail=20 -m "not multiaccelerator" tests/pallas/tpu_ops_test.py

# Run Pallas printing tests, which need to run with I/O capturing disabled.
export TPU_STDERR_LOG_LEVEL=0
"$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest
TPU_STDERR_LOG_LEVEL=0 "$JAXCI_PYTHON" -m pytest -s tests/pallas/tpu_pallas_test.py::PallasCallPrintTest

# Run multi-accelerator across all chips
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
"$JAXCI_PYTHON" -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests/pjit_test.py

0 comments on commit d05ab5b

Please sign in to comment.