diff --git a/foolbox/models/tensorflow.py b/foolbox/models/tensorflow.py index 394f9e9b..c1a4743c 100644 --- a/foolbox/models/tensorflow.py +++ b/foolbox/models/tensorflow.py @@ -10,7 +10,7 @@ def get_device(device: Any) -> Any: import tensorflow as tf if device is None: - device = tf.device("/GPU:0" if tf.test.is_gpu_available() else "/CPU:0") + device = tf.device("/GPU:0" if tf.config.list_physical_devices("GPU") else "/CPU:0") if isinstance(device, str): device = tf.device(device) return device diff --git a/tests/conftest.py b/tests/conftest.py index a258fa3e..8303833b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -266,7 +266,7 @@ def tensorflow_resnet50(request: Any) -> ModelAndData: import tensorflow as tf - if not tf.test.is_gpu_available(): + if not list_physical_devices("GPU"): pytest.skip("ResNet50 test too slow without GPU") model = tf.keras.applications.ResNet50(weights="imagenet")