From 5315cd8532dce315aebe693943f7beecef6e497a Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Tue, 24 Sep 2024 14:11:20 -0700
Subject: [PATCH 1/2] Set TF_FORCE_GPU_ALLOW_GROWTH=true by default

This is needed to be able to run Fuji v2 70B on GPU without GPU memory OOMs.
---
 axlearn/common/launch.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py
index 843454aa..e78ad2ea 100644
--- a/axlearn/common/launch.py
+++ b/axlearn/common/launch.py
@@ -41,6 +41,10 @@
 # Note: this will disable other TF_CPP info and warnnings.
 os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
 
+# Prevent GPU OOM issues due to TF taking up all the GPU memory.
+# Reference: https://stackoverflow.com/a/54927279
+os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
+
 # Import jax before tensorflow else to avoid problems such as:
 # tpu_library_init_fns.inc:98] TpuEmbeddingEngine_ExecutePartitioner not available in this library.
 import jax  # jax must be imported before tensorflow!

From 384bef4de20baafbead16f68291dd7598246105a Mon Sep 17 00:00:00 2001
From: Sam Stoelinga
Date: Tue, 24 Sep 2024 14:39:46 -0700
Subject: [PATCH 2/2] only set gpu env variables if instance type is gpu

---
 axlearn/common/launch.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/axlearn/common/launch.py b/axlearn/common/launch.py
index e78ad2ea..b9a6cf46 100644
--- a/axlearn/common/launch.py
+++ b/axlearn/common/launch.py
@@ -41,9 +41,10 @@
 # Note: this will disable other TF_CPP info and warnnings.
 os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "2")
 
-# Prevent GPU OOM issues due to TF taking up all the GPU memory.
-# Reference: https://stackoverflow.com/a/54927279
-os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
+if instance_type.startswith("gpu"):
+    # Prevent GPU OOM issues due to TF taking up all the GPU memory.
+    # Reference: https://stackoverflow.com/a/54927279
+    os.environ.setdefault("TF_FORCE_GPU_ALLOW_GROWTH", "true")
 
 # Import jax before tensorflow else to avoid problems such as:
 # tpu_library_init_fns.inc:98] TpuEmbeddingEngine_ExecutePartitioner not available in this library.