Why are attention_scores computed pre masking? #63

samefarrar opened this issue Oct 10, 2024 · 0 comments

I've been trying to figure out attention_entropy; please let me know if I've misunderstood.

def sample(gen_tokens: jax.Array, logits: jax.Array, attention_scores: jax.Array, cfg: SamplerConfig,
           clarifying_question_token: int = 2564, key=jax.random.PRNGKey(1337)) -> jax.Array:
    print(attention_scores)  # debug print: scores still contain zeros for not-yet-generated positions
    metrics = calculate_metrics(logits, attention_scores)
    ent, vent = metrics["logits_entropy"], metrics["logits_varentropy"]
    # ... (rest of sample() omitted)

The scores cover the whole possible sequence length (shape (1, 32, 1, 4096)), with future positions having a score of 0:

[[[[ 1.8297775   1.8289843  -4.257527   ...  0.          0.
     0.        ]]
  [[ 1.7098694   1.709356   -3.5842574  ...  0.          0.
     0.        ]]
  [[ 0.830071    0.82948154 -3.2098978  ...  0.          0.
     0.        ]]

This seems to inflate attention_entropy and make attention_varentropy very low. Intuitively, as output sequences get closer to the maximum sequence length, attention entropy will collapse sharply because there are fewer 0s left in the scores, and attention_varentropy will increase. It also means that the attention_probs for future positions are non-zero in calculate_metrics() after the softmax. As a result, the thresholds set in the frog branch sampler.py may become less well calibrated as sequences get longer.
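
To make the inflation concrete, here is a minimal standalone sketch (not the repo's calculate_metrics; the three filled positions, the MASK_VALUE stand-in, and the entropy-in-bits convention are assumptions for illustration) of what the softmax does to the zero-filled positions:

import jax
import jax.numpy as jnp

SEQ_LEN = 4096
MASK_VALUE = -1e30  # stand-in for DEFAULT_MASK_VALUE (a large negative constant)

def entropy_bits(scores):
    # Softmax over the key dimension, then Shannon entropy in bits.
    probs = jax.nn.softmax(scores, axis=-1)
    return -jnp.sum(probs * jnp.log2(jnp.clip(probs, 1e-10, 1.0)), axis=-1)

# Pretend only the first 3 positions hold real scores; the rest are the padded zeros.
cur_pos = 3
real = jnp.array([1.83, 1.71, 0.83])
scores = jnp.concatenate([real, jnp.zeros(SEQ_LEN - cur_pos)])

print(entropy_bits(scores))  # ~12 bits: dominated by the 4093 zero-score positions
masked = jnp.where(jnp.arange(SEQ_LEN) >= cur_pos, MASK_VALUE, scores)
print(entropy_bits(masked))  # ~1.5 bits: entropy over the 3 real positions only

With thousands of zero-score positions, the post-softmax distribution is close to uniform over them, so the entropy sits just under log2(4096) = 12 bits almost regardless of what the real attention pattern looks like.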

[plot: attention entropy vs. sequence length]

# Attention Entropy Thresholds
low_attention_entropy_threshold: float = 11.915
medium_attention_entropy_threshold: float = 11.921
high_attention_entropy_threshold: float = 11.926
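
For reference, if the entropy here is measured in bits, a distribution that is nearly uniform over all 4096 key positions has entropy log2(4096) = 12, so thresholds packed into a ~0.01-wide band just below 12 look consistent with the zero-filled positions dominating the metric rather than with the model's actual attention pattern.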

In testing (in the jax ipynb), I've found that masking the future attention scores with your default mask value makes attention_entropy and attention_varentropy more stable (if not slightly increasing, which is what you'd expect when there are more tokens to attend to) at higher sequence lengths.

logits, kvcache, scores, stats = xfmr(xfmr_weights, model_params, next_token, cur_pos, freqs_cis[cur_pos:cur_pos+1], kvcache)
# True for positions that haven't been generated yet
mask = jnp.arange(scores.shape[-1]) >= cur_pos
# Expand mask to match scores shape: (1, 32, 1, 4096)
mask = mask.reshape(1, 1, 1, -1)
# Replace future positions with the large negative mask value before computing metrics
scores = jnp.where(mask, DEFAULT_MASK_VALUE, scores)
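
One advantage of masking over slicing the scores to scores[..., :cur_pos + 1] before the softmax is that the (1, 32, 1, 4096) shape stays static, so it should still work under jit, where a slice whose length depends on cur_pos would not trace; numerically the two should give essentially the same entropy, since the masked positions end up with ~0 probability.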

[plot: attention entropy and varentropy with future scores masked, across sequence lengths]
