Skip to content

Commit

Permalink
test autograd
Browse files Browse the repository at this point in the history
  • Loading branch information
zinccat committed Oct 22, 2024
1 parent 334d622 commit 0b88e59
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions tests/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,81 @@ def test_score_mod(self):
)

self.assertEqual(output.shape, (batch_size, num_heads, seq_len_q, feature_size))

def test_autograd(self):
# Prepare inputs
batch_size = 4
num_heads = 8
seq_len_q = 64
seq_len_kv = 64
feature_size = 32

# Random tensors for query, key, and value
key = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
)
query = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
)
value = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
)

def fn(query, key, value):
return flax_attention(
query,
key,
value,
).sum()

grad_fn = jax.grad(fn, 0)
grad = grad_fn(query, key, value)

self.assertEqual(grad.shape, (batch_size, num_heads, seq_len_q, feature_size))

def test_autograd_equivalence_with_torch(self):
# Prepare inputs
batch_size = 4
num_heads = 8
seq_len_q = 64
seq_len_kv = 64
feature_size = 32

# Random tensors for query, key, and value
key = jax.random.normal(
jax.random.PRNGKey(0), (batch_size, num_heads, seq_len_kv, feature_size)
)
query = jax.random.normal(
jax.random.PRNGKey(1), (batch_size, num_heads, seq_len_q, feature_size)
)
value = jax.random.normal(
jax.random.PRNGKey(2), (batch_size, num_heads, seq_len_kv, feature_size)
)

def fn(query, key, value):
return flax_attention(
query,
key,
value,
).sum()

grad_fn = jax.grad(fn, 0)
grad_jax = grad_fn(query, key, value)

query_torch = jax2torch(query)
key_torch = jax2torch(key)
value_torch = jax2torch(value)

query_torch.requires_grad = True

output_torch = flex_attention(
query_torch,
key_torch,
value_torch,
).sum()

output_torch.backward()

grad_torch = query_torch.grad.cpu().numpy()

np.testing.assert_almost_equal(grad_jax, grad_torch, decimal=3)

0 comments on commit 0b88e59

Please sign in to comment.