Skip to content

Commit

Permalink
Make log spam test run and pass in Cloud TPU CI
Browse files Browse the repository at this point in the history
  • Loading branch information
skye committed Dec 16, 2023
1 parent 732fd15 commit e5dbca5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 12 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ jobs:
--maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
# This test needs to be run by itself because it starts a subprocess
# that needs to access the TPU (only one process can use the TPU at a time)
python3 test/logging_test.py LoggingTest.test_no_log_spam
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
Expand Down
11 changes: 0 additions & 11 deletions jax/_src/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,3 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'

# TODO(skyewm): remove this warning at some point, say around Sept 2023.
use_pjrt_c_api = os.environ.get('JAX_USE_PJRT_C_API_ON_TPU', None)
if use_pjrt_c_api:
warnings.warn(
"JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (the new TPU "
"runtime is always enabled now). Unset the environment variable "
"to disable this warning.")

# Remove when minimum jaxlib version is >= 0.4.15
os.environ['JAX_USE_PJRT_C_API_ON_TPU'] = "true"
5 changes: 4 additions & 1 deletion tests/logging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import contextlib
import io
import logging
import os
import subprocess
import sys
import textwrap
Expand Down Expand Up @@ -69,8 +70,10 @@ def test_no_log_spam(self):
python = sys.executable
assert "python" in python
# Make sure C++ logging is at default level for the test process.
new_env = dict(os.environ)
new_env["TF_CPP_MIN_LOG_LEVEL"] = "1"
proc = subprocess.run([python, "-c", program], capture_output=True,
env={"TF_CPP_MIN_LOG_LEVEL": "1"})
env=new_env)

lines = proc.stdout.split(b"\n")
lines.extend(proc.stderr.split(b"\n"))
Expand Down

0 comments on commit e5dbca5

Please sign in to comment.