diff --git a/jax/experimental/pallas/ops/gpu/attention_mgpu.py b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
index 6f02396ccb92..b684aef409f1 100644
--- a/jax/experimental/pallas/ops/gpu/attention_mgpu.py
+++ b/jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -249,25 +249,26 @@ def attention_reference(q, k, v):
 
 
 def main(unused_argv):
-  num_q_heads = 1
-  num_kv_heads = 1
-  problem_it = itertools.product((1, 2), (4096, 32768,), (64, 128, 256,))
+  num_q_heads = 16
+  num_kv_heads = 16
+  problem_it = itertools.product((1,), (4096, 32768,), (64, 128, 256,))
   for batch_size, seq_len, head_dim in problem_it:
     q_seq_len = kv_seq_len = seq_len
     print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}"
           f"{num_q_heads=:<4} {head_dim=:<6} ====")
-    param_it = itertools.product((64,), (64, 128, 256))
-    best = None
     k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
     q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
     k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
     v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
-    for block_q, block_kv in param_it:
+    block_q = 64
+    best = None
+    for block_kv in (256, 128, 64):
       config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
       try:
         out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
-        out_ref = attention_reference(q, k, v)
-        np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
+        if seq_len < 32768:
+          out_ref = attention_reference(q, k, v)
+          np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
       except ValueError as e:
         if "exceeds available shared memory" in e.args[0]:
           continue
@@ -285,6 +286,7 @@ def main(unused_argv):
       )
       if best is None or runtime_us < best[0]:
         best = (runtime_us, achieved_tc_util)
+      break  # Remove this for full autotuning.
     if best is not None:
       print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization")
 
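
For reference, the pre-patch behavior can be restored later by dropping the early break and sweeping block_q again. Below is a minimal sketch of such a full autotuning loop, assuming it lives inside attention_mgpu.py so that TuningConfig, attention, and profiler.measure (as used in the hunks above) are in scope; the helper name autotune and the returned (runtime, config) tuple are illustrative, and the per-config printing and TC-utilization bookkeeping from main() are omitted.

# Illustrative sketch only, not part of the patch. Assumes it runs inside
# attention_mgpu.py, so TuningConfig, attention, and profiler are in scope.
import functools
import itertools

def autotune(q, k, v):
  best = None  # (runtime_us, config) of the fastest configuration so far.
  for block_q, block_kv in itertools.product((64,), (256, 128, 64)):
    config = TuningConfig(
        block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
    try:
      _, runtime_ms = profiler.measure(
          functools.partial(attention, config=config), q, k, v)
    except ValueError as e:
      if "exceeds available shared memory" in e.args[0]:
        continue  # This config does not fit in shared memory; try the next one.
      raise
    runtime_us = runtime_ms * 1000
    if best is None or runtime_us < best[0]:
      best = (runtime_us, config)
  return best

As in the patch, configurations whose scratch buffers exceed shared memory are skipped rather than treated as failures.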