diff --git a/.github/workflows/asan.yaml b/.github/workflows/asan.yaml new file mode 100644 index 000000000000..fd3054aabbef --- /dev/null +++ b/.github/workflows/asan.yaml @@ -0,0 +1,80 @@ +name: CI - Address Sanitizer (nightly) + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +on: + schedule: + - cron: "0 12 * * *" # Daily at 12:00 UTC + workflow_dispatch: # allows triggering the workflow run manually + pull_request: # Automatically trigger on pull requests affecting this file + branches: + - main + paths: + - '**workflows/asan.yml' + +jobs: + upstream-dev: + runs-on: ubuntu-20.04-16core + strategy: + fail-fast: false + matrix: + python-version: ["3.12"] + steps: + - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 + with: + path: jax + - uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 + with: + repository: python/cpython + path: cpython + ref: v3.12.6 + - name: Install clang 18 + run: | + wget -qO- https://apt.llvm.org/llvm-snapshot.gpg.key | sudo tee /etc/apt/trusted.gpg.d/apt.llvm.org.asc + echo deb http://apt.llvm.org/focal/ llvm-toolchain-focal-18 main | sudo tee -a /etc/apt/sources.list + sudo apt update + sudo apt install clang-18 libstdc++-10-dev + - name: Build CPython with ASAN enabled + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + cd cpython + mkdir ${GITHUB_WORKSPACE}/cpythonasan + CC=clang-18 ./configure --prefix ${GITHUB_WORKSPACE}/cpythonasan --with-address-sanitizer --without-pymalloc + make -j16 + make install + ${GITHUB_WORKSPACE}/cpythonasan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv + - name: Install JAX test requirements + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + source ${GITHUB_WORKSPACE}/venv/bin/activate + cd jax + pip install -r build/test-requirements.txt + - name: Build and install JAX + env: + ASAN_OPTIONS: detect_leaks=0 + run: | + source ${GITHUB_WORKSPACE}/venv/bin/activate + cd jax + python build/build.py --bazel_options=--copt=-fsanitize=address --clang_path=/usr/bin/clang-18 + pip install dist/jaxlib-*.whl + pip install -e . + - name: Run tests + env: + ASAN_OPTIONS: detect_leaks=0 + JAX_NUM_GENERATED_CASES: 1 + JAX_ENABLE_X64: true + JAX_SKIP_SLOW_TESTS: true + PY_COLORS: 1 + run: | + source ${GITHUB_WORKSPACE}/venv/bin/activate + cd jax + echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES" + echo "JAX_ENABLE_X64=$JAX_ENABLE_X64" + echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS" + # The LD_PRELOAD works around https://github.com/google/sanitizers/issues/934#issuecomment-649516500 + LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 python -m pytest -n auto --tb=short --maxfail=20 tests + \ No newline at end of file