Skip to content

Commit

Permalink
Add set_backend() utility.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Mar 11, 2024
1 parent b97338e commit 0d0be6a
Showing 1 changed file with 42 additions and 0 deletions.
42 changes: 42 additions & 0 deletions keras/utils/backend_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import copy
import importlib
import os
import sys

from keras import backend as backend_module
from keras.api_export import keras_export
from keras.backend.common import global_state


Expand Down Expand Up @@ -82,3 +86,41 @@ def __getattr__(self, name):
from keras import backend as numpy_backend

return getattr(numpy_backend, name)


@keras_export("keras.config.set_backend")
def set_backend(backend):
"""Reload the backend (and the Keras package).
Example:
```python
keras.config.set_backend("jax")
```
Note that this will **NOT** convert the type of any already
instantiated objects, except for the `keras` module itself.
Thus, any layers / tensors / etc. already created will no
longer be usable without errors. It is strongly recommended **not**
to keep around **any** Keras-originated objects instances created
before calling `set_backend()`.
"""
os.environ["KERAS_BACKEND"] = backend
# Clear module cache.
loaded_modules = [
key for key in sys.modules.keys() if key.startswith("keras")
]
for key in loaded_modules:
del sys.modules[key]
# Reimport Keras with the new backend (set via KERAS_BACKEND).
import keras

# Finally: refresh all imported Keras submodules.
globs = copy.copy(globals())
for key, value in globs.items():
if value.__class__ == keras.__class__:
if str(value).startswith("<module 'keras."):
module_name = str(value)
module_name = module_name[module_name.find("'") + 1 :]
module_name = module_name[: module_name.find("'")]
globals()[key] = importlib.import_module(module_name)

0 comments on commit 0d0be6a

Please sign in to comment.