-
Notifications
You must be signed in to change notification settings - Fork 2.9k
99 lines (94 loc) · 4.69 KB
/
cloud-tpu-ci-nightly.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Cloud TPU CI
#
# This job currently runs once per day. We use self-hosted TPU runners, so we'd
# have to add more runners to run on every commit.
#
# This job's build matrix runs over several TPU architectures using both the
# latest released jaxlib on PyPi ("pypi_latest") and the latest nightly
# jaxlib.("nightly"). It also installs a matching libtpu, either the one pinned
# to the release for "pypi_latest", or the latest nightly.for "nightly". It
# always locally installs jax from github head (already checked out by the
# Github Actions environment).
name: CI - Cloud TPU (nightly)
on:
schedule:
- cron: "0 14 * * *" # daily at 7am PST
workflow_dispatch: # allows triggering the workflow run manually
# This should also be set to read-only in the project settings, but it's nice to
# document and enforce the permissions here.
permissions:
contents: read
jobs:
cloud-tpu-test:
strategy:
fail-fast: false # don't cancel all jobs on failure
matrix:
jaxlib-version: ["pypi_latest", "nightly", "nightly+oldest_supported_libtpu"]
tpu-type: ["v3-8", "v4-8", "v5e-4"]
name: "TPU test (jaxlib=${{ matrix.jaxlib-version }}, ${{ matrix.tpu-type }})"
env:
LIBTPU_OLDEST_VERSION_DATE: 20230912
ENABLE_PJRT_COMPATIBILITY: ${{ matrix.jaxlib-version == 'nightly+oldest_supported_libtpu' }}
runs-on: ["self-hosted", "tpu", "${{ matrix.tpu-type }}"]
timeout-minutes: 60
defaults:
run:
shell: bash -ex {0}
steps:
# https://opensource.google/documentation/reference/github/services#actions
# mandates using a specific commit for non-Google actions. We use
# https://github.com/sethvargo/ratchet to pin specific versions.
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Install JAX test requirements
run: |
pip install -U -r build/test-requirements.txt
pip install -U -r build/collect-profile-requirements.txt
- name: Install JAX
run: |
pip uninstall -y jax jaxlib libtpu-nightly
if [ "${{ matrix.jaxlib-version }}" == "pypi_latest" ]; then
pip install .[tpu] \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
elif [ "${{ matrix.jaxlib-version }}" == "nightly" ]; then
pip install .
pip install --pre jaxlib \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install --pre libtpu-nightly \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests
elif [ "${{ matrix.jaxlib-version }}" == "nightly+oldest_supported_libtpu" ]; then
pip install .
pip install --pre jaxlib \
-f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
pip install --pre libtpu-nightly==0.1.dev${{ env.LIBTPU_OLDEST_VERSION_DATE }} \
-f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install requests
else
echo "Unknown jaxlib-version: ${{ matrix.jaxlib-version }}"
exit 1
fi
python3 -c 'import sys; print("python version:", sys.version)'
python3 -c 'import jax; print("jax version:", jax.__version__)'
python3 -c 'import jaxlib; print("jaxlib version:", jaxlib.__version__)'
strings $HOME/.local/lib/python3.10/site-packages/libtpu/libtpu.so | grep 'Built on'
python3 -c 'import jax; print("libtpu version:",
jax.lib.xla_bridge.get_backend().platform_version)'
- name: Run tests
env:
JAX_PLATFORMS: tpu,cpu
PY_COLORS: 1
run: |
# Run single-accelerator tests in parallel
JAX_ENABLE_TPU_XDIST=true python3 -m pytest -n=4 --tb=short \
--maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
run: |
curl --location --request POST '${{ secrets.BUILD_CHAT_WEBHOOK }}' \
--header 'Content-Type: application/json' \
--data-raw "{
'text': '\"$GITHUB_WORKFLOW\", jaxlib/libtpu version \"${{ matrix.jaxlib-version }}\", TPU type ${{ matrix.tpu-type }} job failed, timed out, or was cancelled: $GITHUB_SERVER_URL/$GITHUB_REPOSITORY/actions/runs/$GITHUB_RUN_ID'
}"