Commit
[AutoPGLE] Use compile options to override debug options instead of XLA_FLAGS.

PiperOrigin-RevId: 697924164
Google-ML-Automation committed Nov 19, 2024
1 parent d397dd9 commit da50ad7
Showing 1 changed file with 153 additions and 174 deletions.
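In short, the per-process XLA_FLAGS plumbing in setUpClass/tearDownClass is replaced by per-jit compiler options. Below is a minimal sketch of the new pattern; the compiler_options keys are the ones used in the tests in this diff, while the toy function and input are illustrative only:

    from functools import partial

    import jax
    import jax.numpy as jnp

    @partial(
        jax.jit,
        compiler_options={
            # Previously exported process-wide via XLA_FLAGS in setUpClass.
            'xla_gpu_enable_latency_hiding_scheduler': 'True',
            # Workaround carried over from the old --xla_gpu_graph_level=0 flag
            # (see the TODO referencing b/376647494 in the diff below).
            'xla_gpu_graph_min_graph_size': '100000',
        },
    )
    def f(x):
        return x * 2

    f(jnp.arange(8, dtype=jnp.float32))  # compiled with the options above
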
327 changes: 153 additions & 174 deletions tests/pgle_test.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import ExitStack
from functools import partial
import glob
import logging
@@ -43,54 +42,14 @@


@jtu.pytest_mark_if_available('multiaccelerator')
# TODO(patrios): Remove this skip once b/379267258 is fixed.
@jtu.skip_under_pytest(
'This test requires specific XLA_FLAGS. However, pytest does not reload '
'modules between tests. So if another test is launched before this one '
'necessary XLA_FLAGS will not be re-used by the XLA.')
class PgleTest(jtu.JaxTestCase):
_dump_exit_stack: ExitStack | None = None

@classmethod
def setUpClass(cls):
super().setUpClass()
cls._dump_exit_stack = ExitStack()

cls.dump_dir = cls._dump_exit_stack.enter_context(tempfile.TemporaryDirectory())
if 'XLA_FLAGS' in os.environ:
cls.old_xla_flags = os.environ['XLA_FLAGS']
else:
cls.old_xla_flags = None

os.environ['XLA_FLAGS'] = (
f'--xla_dump_to={cls.dump_dir}'
' --xla_gpu_experimental_dump_fdo_profiles=true'
' --xla_gpu_enable_latency_hiding_scheduler=true'
# TODO(patrios): Remove this flag once b/376647494 is fixed.
' --xla_gpu_graph_level=0'
)
if cls.old_xla_flags:
os.environ['XLA_FLAGS'] += ' ' + cls.old_xla_flags

@classmethod
def tearDownClass(cls):
if cls.old_xla_flags:
os.environ['XLA_FLAGS'] = cls.old_xla_flags
cls._dump_exit_stack.close()
super().tearDownClass()

def setUp(self):
super().setUp()
cc.set_cache_dir(None)
cc.reset_cache()

def tearDown(self):
# Cleanup dump directory
for file in os.listdir(self.dump_dir):
file_path = os.path.join(self.dump_dir, file)
if os.path.isfile(file_path):
os.remove(file_path)

cc.set_cache_dir(None)
super().tearDown()

@@ -101,6 +60,7 @@ def testPGLEProfilerGetFDOProfile(self):
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'},
)
def f(x, y):
return x @ y
@@ -130,6 +90,11 @@ def testPGLEProfilerGetFDOProfileLarge(self):
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
},
)
def f(x):
agg = x
@@ -154,6 +119,11 @@ def testAutoPgle(self):
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
},
)
def f(x):
return x * 2
@@ -172,7 +142,7 @@ def f(x):
# Run 2: Second PGLE run should not recompile the module
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 0)
self.assertLess(cache_miss_count[0], 2)

# Run 3: The module should be recompiled with FDO profiles
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
@@ -182,7 +152,7 @@ def f(x):
# Run 4: Fast-path should be used after PGLE is done
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
self.assertArraysEqual(f(x), expected)
self.assertEqual(cache_miss_count[0], 0)
self.assertLess(cache_miss_count[0], 2)

def testAutoPgleWithAot(self):
@jax.jit
@@ -211,145 +181,154 @@ def testAutoPgleWithPersistentCache(self):
its = 50
mesh = jtu.create_mesh((2,), ('x',))

@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x):
agg = x
for _ in range(its):
agg = agg @ x
return agg

shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

with (config.enable_compilation_cache(True),
config.enable_pgle(True),
config.raise_persistent_cache_errors(True),
config.raise_persistent_cache_errors(True),
config.persistent_cache_min_entry_size_bytes(0),
config.persistent_cache_min_compile_time_secs(0),
config.pgle_profiling_runs(2),
tempfile.TemporaryDirectory() as cache_dir):
cc.reset_cache()
cc.set_cache_dir(cache_dir)
# Run 1: Module should be compiled without FDO
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
with tempfile.TemporaryDirectory() as dump_dir:
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={
'xla_gpu_enable_latency_hiding_scheduler': 'True',
# TODO(patrios): Remove this flag once b/376647494 is fixed.
'xla_gpu_graph_min_graph_size': '100000',
'xla_dump_to': dump_dir,
'xla_gpu_experimental_dump_fdo_profiles': 'True'
},
)
def f(x):
agg = x
for _ in range(its):
agg = agg @ x
return agg

shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)

with (config.enable_compilation_cache(True),
config.enable_pgle(True),
config.raise_persistent_cache_errors(True),
config.raise_persistent_cache_errors(True),
config.persistent_cache_min_entry_size_bytes(0),
config.persistent_cache_min_compile_time_secs(0),
config.pgle_profiling_runs(2),
tempfile.TemporaryDirectory() as cache_dir):
cc.reset_cache()
cc.set_cache_dir(cache_dir)
# Run 1: Module should be compiled without FDO
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertGreater(cache_miss_count[0], 0)

# Non-pgle profiled version of module should be saved
non_pgle_profiled_files = os.listdir(cache_dir)
self.assertNotEmpty(non_pgle_profiled_files)

# Run 2: Compilation should not be called
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertLess(cache_miss_count[0], 2)

module_before_pgle = os.listdir(dump_dir)
self.assertNotEmpty(module_before_pgle)
# Run 3: Module should be compiled with FDO and stored to persistent cache
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertGreater(cache_miss_count[0], 0)

# Check if FDO profile file of the biggest module is not empty
module_after_pgle = [
x
for x in os.listdir(dump_dir)
if x not in module_before_pgle
]
self.assertNotEmpty(module_after_pgle)
biggest_module_after_pgle = max(
module_after_pgle,
key=lambda x: os.path.getsize(
os.path.join(dump_dir, x)
),
)
base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1])

# Check if FDO profile file in dump directory is not empty
for module in module_after_pgle:
if module.startswith(base_module_name) and module.endswith(
'.fdo_profile'
):
self.assertGreater(
os.path.getsize(os.path.join(dump_dir, module)), 0
)

for pgle_profiler in pjit._pgle_profiler_dict.values():
self.assertTrue(pgle_profiler.is_enabled())
self.assertTrue(pgle_profiler.is_fdo_consumed())

files_after_pgle_profile = os.listdir(cache_dir)
self.assertGreater(
len(files_after_pgle_profile), len(non_pgle_profiled_files)
)

# Remove the non-PGLE-profiled module from the cache to check that the
# PGLE-profiled version is used later.
for non_pgle_file in non_pgle_profiled_files:
path = os.path.join(cache_dir, non_pgle_file)
if os.path.isfile(path):
os.remove(path)
elif os.path.isdir(path):
shutil.rmtree(path)

api.clear_caches()
pjit._pgle_profiler_dict.clear()

# Run 4: Persistent compilation cache should be hit; the PGLE profiler
# should be disabled.
cache_hit = 0
def check_if_cache_hit(event):
nonlocal cache_hit
if event == '/jax/compilation_cache/cache_hits':
cache_hit += 1

monitoring.register_event_listener(check_if_cache_hit)
f(x)
self.assertGreater(cache_miss_count[0], 0)
monitoring._unregister_event_listener_by_callback(check_if_cache_hit)

# Non-pgle profiled version of module should be saved
non_pgle_profiled_files = os.listdir(cache_dir)
self.assertNotEmpty(non_pgle_profiled_files)
self.assertGreater(cache_hit, 0)

# Run 2: Compilation should not be called
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertEqual(cache_miss_count[0], 0)
def testPassingFDOProfile(self):
mesh = jtu.create_mesh((2,), ('x',))

module_before_pgle = os.listdir(self.dump_dir)
self.assertNotEmpty(module_before_pgle)
# Run 3: Module should be compiled with FDO and stored to persistent cache
with jtu.count_cached_compilation_cache_miss() as cache_miss_count:
f(x)
self.assertGreater(cache_miss_count[0], 0)

# Check if FDO profile file of the biggest module is not empty
module_after_pgle = [
x
for x in os.listdir(self.dump_dir)
if x not in module_before_pgle
]
self.assertNotEmpty(module_after_pgle)
biggest_module_after_pgle = max(
module_after_pgle,
key=lambda x: os.path.getsize(
os.path.join(self.dump_dir, x)
),
)
base_module_name = '.'.join(biggest_module_after_pgle.split('.')[0:1])

# Check if FDO profile file in dump directory is not empty
for module in module_after_pgle:
if module.startswith(base_module_name) and module.endswith(
'.fdo_profile'
):
self.assertGreater(
os.path.getsize(os.path.join(self.dump_dir, module)), 0
)

for pgle_profiler in pjit._pgle_profiler_dict.values():
self.assertTrue(pgle_profiler.is_enabled())
self.assertTrue(pgle_profiler.is_fdo_consumed())

files_after_pgle_profile = os.listdir(cache_dir)
self.assertGreater(
len(files_after_pgle_profile), len(non_pgle_profiled_files)
@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
compiler_options={'xla_gpu_enable_latency_hiding_scheduler': 'True'},
)
def f(x, y):
return x @ y

# Remove the non-PGLE-profiled module from the cache to check that the
# PGLE-profiled version is used later.
for non_pgle_file in non_pgle_profiled_files:
path = os.path.join(cache_dir, non_pgle_file)
if os.path.isfile(path):
os.remove(path)
elif os.path.isdir(path):
shutil.rmtree(path)

api.clear_caches()
pjit._pgle_profiler_dict.clear()

# Run 4: Persistent compilation cache should be hit; the PGLE profiler
# should be disabled.
cache_hit = 0
def check_if_cache_hit(event):
nonlocal cache_hit
if event == '/jax/compilation_cache/cache_hits':
cache_hit += 1

monitoring.register_event_listener(check_if_cache_hit)
f(x)
monitoring._unregister_event_listener_by_callback(check_if_cache_hit)

self.assertGreater(cache_hit, 0)

def testPassingFDOProfile(self):
mesh = jtu.create_mesh((2,), ('x',))

@partial(
jax.jit,
in_shardings=NamedSharding(mesh, PartitionSpec('x')),
out_shardings=NamedSharding(mesh, PartitionSpec('x')),
)
def f(x, y):
return x @ y

shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
y = x + 1
shape = (16, 16)
x = jnp.arange(math.prod(shape)).reshape(shape).astype(np.float32)
y = x + 1

with config.pgle_profiling_runs(0):
f_lowered = f.lower(x, y)
compiled = f_lowered.compile()
with config.pgle_profiling_runs(0):
f_lowered = f.lower(x, y)
compiled = f_lowered.compile()

with tempfile.TemporaryDirectory() as cache_dir:
jax.profiler.start_trace(cache_dir)
compiled(x, y)
jax.profiler.stop_trace()
directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/'))
directories = [d for d in directories if os.path.isdir(d)]
rundir = directories[-1]
logging.info('rundir: %s', rundir)
fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)

if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
self.assertIn(b'custom', fdo_profile)

logging.info('fdo_profile: %s', fdo_profile)
# Test that passing fdo_profile via the compiler_options API works.
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})
with tempfile.TemporaryDirectory() as cache_dir:
jax.profiler.start_trace(cache_dir)
compiled(x, y)
jax.profiler.stop_trace()
directories = glob.glob(os.path.join(cache_dir, 'plugins/profile/**/'))
directories = [d for d in directories if os.path.isdir(d)]
rundir = directories[-1]
logging.info('rundir: %s', rundir)
fdo_profile = exp_profiler.get_profiled_instructions_proto(rundir)

if jtu.test_device_matches(['gpu']) and jtu.is_device_cuda():
self.assertIn(b'custom', fdo_profile)

logging.info('fdo_profile: %s', fdo_profile)
# Test that passing fdo_profile via the compiler_options API works.
f_lowered.compile(compiler_options={'fdo_profile': fdo_profile})


if __name__ == '__main__':
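As a closing reference, a small sketch of the ahead-of-time path that testPassingFDOProfile exercises: compiling a lowered computation with an explicit fdo_profile passed through compiler_options. The profile bytes below are a placeholder; in the test they come from a jax.profiler trace via exp_profiler.get_profiled_instructions_proto.

    import jax
    import jax.numpy as jnp

    def matmul(x, y):
        return x @ y

    x = jnp.ones((16, 16), jnp.float32)
    lowered = jax.jit(matmul).lower(x, x)

    # Placeholder: a real profile is the serialized profiled-instructions proto
    # captured from a profiled run of the compiled computation.
    fdo_profile = b''
    compiled = lowered.compile(compiler_options={'fdo_profile': fdo_profile})
    compiled(x, x)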
