Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make log spam test run and pass in Cloud TPU CI #19013

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .github/workflows/cloud-tpu-ci-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ jobs:
--maxfail=20 -m "not multiaccelerator" tests examples
# Run multi-accelerator across all chips
python3 -m pytest --tb=short --maxfail=20 -m "multiaccelerator" tests
# This test needs to be run by itself because it starts a subprocess
# that needs to access the TPU (only one process can use the TPU at a time)
python3 tests/logging_test.py LoggingTest.test_no_log_spam
- name: Send chat on failure
# Don't notify when testing the workflow from a branch.
if: ${{ (failure() || cancelled()) && github.ref_name == 'main' && matrix.jaxlib-version != 'nightly+oldest_supported_libtpu' }}
Expand Down
12 changes: 0 additions & 12 deletions jax/_src/cloud_tpu_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

import os
import warnings

running_in_cloud_tpu_vm: bool = False

Expand Down Expand Up @@ -66,14 +65,3 @@ def cloud_tpu_init() -> None:
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('JAX_PLATFORMS', 'tpu,cpu')
os.environ['TPU_ML_PLATFORM'] = 'JAX'

# TODO(skyewm): remove this warning at some point, say around Sept 2023.
use_pjrt_c_api = os.environ.get('JAX_USE_PJRT_C_API_ON_TPU', None)
if use_pjrt_c_api:
warnings.warn(
"JAX_USE_PJRT_C_API_ON_TPU no longer has an effect (the new TPU "
"runtime is always enabled now). Unset the environment variable "
"to disable this warning.")

# Remove when minimum jaxlib version is >= 0.4.15
os.environ['JAX_USE_PJRT_C_API_ON_TPU'] = "true"
5 changes: 4 additions & 1 deletion tests/logging_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import contextlib
import io
import logging
import os
import subprocess
import sys
import textwrap
Expand Down Expand Up @@ -69,8 +70,10 @@ def test_no_log_spam(self):
python = sys.executable
assert "python" in python
# Make sure C++ logging is at default level for the test process.
new_env = dict(os.environ)
new_env["TF_CPP_MIN_LOG_LEVEL"] = "1"
proc = subprocess.run([python, "-c", program], capture_output=True,
env={"TF_CPP_MIN_LOG_LEVEL": "1"})
env=new_env)

lines = proc.stdout.split(b"\n")
lines.extend(proc.stderr.split(b"\n"))
Expand Down
Loading