Skip to content

Commit

Permalink
Use bazel to run tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vfdev-5 committed Dec 17, 2024
1 parent c8a101d commit 52ce2c1
Show file tree
Hide file tree
Showing 5 changed files with 634 additions and 27 deletions.
48 changes: 23 additions & 25 deletions .github/workflows/tsan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ jobs:
with:
repository: python/cpython
path: cpython
ref: v3.13.0
ref: v3.13.1
- name: Build CPython with TSAN enabled
run: |
cd cpython
Expand All @@ -54,27 +54,27 @@ jobs:
# Check whether free-threading mode is enabled
PYTHON_GIL=0 ${GITHUB_WORKSPACE}/cpython-tsan/bin/python3 -c "import sys; assert not sys._is_gil_enabled()"
${GITHUB_WORKSPACE}/cpython-tsan/bin/python3 -m venv ${GITHUB_WORKSPACE}/venv
- name: Install JAX test requirements
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
python -m pip install -r build/test-requirements.txt
# Create archive to be used with bazel as hermetic python:
cd ${GITHUB_WORKSPACE} && tar -czpf python-tsan.tgz cpython-tsan
- name: Build and install JAX
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
export PYTHON_SHA256=($(sha256sum ${GITHUB_WORKSPACE}/python-tsan.tgz))
echo "Python sha256: ${PYTHON_SHA256}"
python build/build.py build --wheels=jaxlib \
--bazel_options=--repo_env=HERMETIC_PYTHON_VERSION=3.13-ft \
--bazel_options=--repo_env=HERMETIC_PYTHON_URL="file://${GITHUB_WORKSPACE}/python-tsan.tgz" \
--bazel_options=--repo_env=HERMETIC_PYTHON_SHA256=${PYTHON_SHA256} \
--bazel_options=--repo_env=HERMETIC_PYTHON_PREFIX="cpython-tsan/" \
--bazel_options=--color=yes \
--bazel_options=--copt=-fsanitize=thread \
--bazel_options=--linkopt="-fsanitize=thread" \
--bazel_options=--@rules_python//python/config_settings:py_freethreaded="yes" \
--bazel_options=--@nanobind//:enabled_free_threading=True \
--clang_path=/usr/bin/clang-18
# We have to manually install nightly scipy, otherwise default scipy installation
# is failing to build it here: ../meson.build:84:0: ERROR: Unknown compiler(s)
python -m pip install -i https://pypi.anaconda.org/scientific-python-nightly-wheels/simple scipy
python -m pip install dist/jaxlib-*.whl
python -m pip install -e .
- name: Run tests
timeout-minutes: 30
env:
Expand All @@ -83,28 +83,26 @@ jobs:
JAX_SKIP_SLOW_TESTS: true
PY_COLORS: 1
run: |
source ${GITHUB_WORKSPACE}/venv/bin/activate
cd jax
echo "JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES"
echo "JAX_ENABLE_X64=$JAX_ENABLE_X64"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
# As we do not have yet free-threading support
# there will be the following warning:
# RuntimeWarning: The global interpreter lock (GIL) has been enabled to load module 'jaxlib.utils',
# which has not declared that it can run safely without the GIL.
# To avoid that we temporarily define PYTHON_GIL
export PYTHON_GIL=0
# Continue running all commands even if they failing
set +e
python -m pytest -s -vvv tests/jaxpr_effects_test.py::EffectOrderingTest::test_different_threads_get_different_tokens
exit_code=$?
python -m pytest -s -vvv tests/api_test.py::CustomJVPTest::test_concurrent_initial_style
exit_code=$(( $exit_code | $? ))
python -m pytest -s -vvv tests/api_test.py::APITest::test_concurrent_device_get_and_put
exit_code=$(( $exit_code | $? ))
python -m pytest -s -vvv tests/api_test.py::JitTest::test_concurrent_jit
exit_code=$(( $exit_code | $? ))
bazel_exec=$(ls build/bazel-*)
ln -s ${bazel_exec} bazel
exit $exit_code
./bazel test \
--repo_env=HERMETIC_PYTHON_VERSION=3.13-ft \
--repo_env=JAX_NUM_GENERATED_CASES=$JAX_NUM_GENERATED_CASES \
--repo_env=JAX_ENABLE_X64=$JAX_ENABLE_X64 \
--repo_env=JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS \
--repo_env=PYTHON_GIL=$PYTHON_GIL \
--//jax:build_jaxlib=false \
//tests:cpu_tests
14 changes: 12 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,17 @@ load("//third_party/xla:workspace.bzl", jax_xla_workspace = "repo")
jax_xla_workspace()

# Initialize hermetic Python
load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
python_init_rules()
# load("@xla//third_party/py:python_init_rules.bzl", "python_init_rules")
# python_init_rules()
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "rules_python",
sha256 = "62ddebb766b4d6ddf1712f753dac5740bea072646f630eb9982caa09ad8a7687",
strip_prefix = "rules_python-0.39.0",
url = "https://github.com/bazelbuild/rules_python/releases/download/0.39.0/rules_python-0.39.0.tar.gz",
patch_args = ["-p1"],
patches = ["//third_party/rules_python:rules_python.patch"],
)

load("@xla//third_party/py:python_init_repositories.bzl", "python_init_repositories")
python_init_repositories(
Expand All @@ -13,6 +22,7 @@ python_init_repositories(
"3.11": "//build:requirements_lock_3_11.txt",
"3.12": "//build:requirements_lock_3_12.txt",
"3.13": "//build:requirements_lock_3_13.txt",
"3.13-ft": "//build:requirements_lock_3_13_ft.txt",
},
local_wheel_inclusion_list = [
"jaxlib*",
Expand Down
Loading

0 comments on commit 52ce2c1

Please sign in to comment.