diff --git a/jax/_src/lib/mosaic_gpu.py b/jax/_src/lib/mosaic_gpu.py index 233b51db4d6f..137b132778b9 100644 --- a/jax/_src/lib/mosaic_gpu.py +++ b/jax/_src/lib/mosaic_gpu.py @@ -17,12 +17,9 @@ try: try: from jaxlib.mosaic.gpu import _mosaic_gpu_ext # pytype: disable=import-error - except ImportError as e: - print(e) + except ImportError: from jax_cuda12_plugin import _mosaic_gpu_ext # pytype: disable=import-error -except ImportError as e: - print("="*128) - print(e) +except ImportError: raise ModuleNotFoundError("Failed to import the Mosaic GPU bindings") from e else: _mosaic_gpu_ext.register_passes()