[Pallas:MGPU] Make the shapes from the attention example more interesting

This bumps up the number of heads and removes the batch_size=2 case: it is
very similar to batch_size=1 and doubles the script runtime. We also skip
full autotuning by default, since the largest block size that works usually
performs best.

PiperOrigin-RevId: 701976192
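The tuning loop below now uses a first-fit search: candidate block_kv sizes are tried from largest to smallest, and the break keeps only the first configuration that fits in shared memory. A minimal sketch of that pattern, with run_config as a hypothetical stand-in for the TuningConfig + profiler.measure call in the diff:

    def first_fit(candidates, run_config):
      # Try the largest candidate first; the first one that fits usually wins.
      for block_kv in sorted(candidates, reverse=True):
        try:
          return block_kv, run_config(block_kv)
        except ValueError as e:
          if "exceeds available shared memory" in e.args[0]:
            continue  # Block too large for SMEM; try the next smaller size.
          raise
      return None  # No candidate fit in shared memory.

    # E.g. first_fit((256, 128, 64), run_config) mirrors the loop in main().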
apaszke authored and Google-ML-Automation committed Dec 2, 2024
1 parent aff7714 commit 8a31619
Showing 1 changed file with 10 additions and 8 deletions.

jax/experimental/pallas/ops/gpu/attention_mgpu.py
@@ -249,25 +249,26 @@ def attention_reference(q, k, v):
 
 
 def main(unused_argv):
-  num_q_heads = 1
-  num_kv_heads = 1
-  problem_it = itertools.product((1, 2), (4096, 32768,), (64, 128, 256,))
+  num_q_heads = 16
+  num_kv_heads = 16
+  problem_it = itertools.product((1,), (4096, 32768,), (64, 128, 256,))
   for batch_size, seq_len, head_dim in problem_it:
     q_seq_len = kv_seq_len = seq_len
     print(f"==== {batch_size=:<6} {kv_seq_len=:<6} {q_seq_len=:<6}"
           f"{num_q_heads=:<4} {head_dim=:<6} ====")
-    param_it = itertools.product((64,), (64, 128, 256))
-    best = None
     k1, k2, k3 = jax.random.split(jax.random.key(42), 3)
     q = jax.random.normal(k1, (batch_size, q_seq_len, num_q_heads, head_dim), jnp.float16)
     k = jax.random.normal(k2, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
     v = jax.random.normal(k3, (batch_size, kv_seq_len, num_kv_heads, head_dim), jnp.float16)
-    for block_q, block_kv in param_it:
+    block_q = 64
+    best = None
+    for block_kv in (256, 128, 64):
       config = TuningConfig(block_q=block_q, block_kv=block_kv, max_concurrent_steps=2)
       try:
         out, runtime_ms = profiler.measure(functools.partial(attention, config=config), q, k, v)
-        out_ref = attention_reference(q, k, v)
-        np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
+        if seq_len < 32768:
+          out_ref = attention_reference(q, k, v)
+          np.testing.assert_allclose(out, out_ref, atol=2e-3, rtol=1e-3)
       except ValueError as e:
         if "exceeds available shared memory" in e.args[0]:
           continue
@@ -285,6 +285,7 @@ def main(unused_argv):
       )
       if best is None or runtime_us < best[0]:
         best = (runtime_us, achieved_tc_util)
+      break  # Remove this for full autotuning.
     if best is not None:
       print(f"Best: {best[0]:<7.1f}us = {best[1]:4.1f}% TC utilization")
 
