Refactoring of GPT.forward when it comes to input_pos and KV cache usage #1898

Open · mseeger opened this issue Jan 6, 2025 · 6 comments
Labels: enhancement (New feature or request)


mseeger commented Jan 6, 2025

The current GPT.forward in model.py essentially serves two use cases:

  • Forward pass for training: input_pos=None, the KV cache is not used. Implicitly, input_pos = arange(idx.shape[-1]), and causal masking is applied. This path could also be used for prefill with the prompt during inference.
  • Inference: input_pos is not None, and the KV cache is used. There seem to be two cases here: either input_pos = arange(idx.shape[-1]) (used for prefill), or input_pos.shape[-1] == 1 (generation of a single next token, possibly batched). A sketch of both entry points follows this list.
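Concretely, the two entry points look roughly like this (a sketch against the current litgpt API; Config.from_name and set_kv_cache are assumed from recent versions and details may differ):

```python
import torch
from litgpt.model import GPT, Config  # assuming a recent litgpt version

config = Config.from_name("pythia-160m")   # any small config works for illustration
model = GPT(config)
B, T = 1, 8
idx = torch.randint(0, config.padded_vocab_size, (B, T))

# Case 1: training-style forward. No KV cache, causal masking, implicit positions.
logits = model(idx)

# Case 2: inference. The KV cache is written at the slots given by input_pos.
model.set_kv_cache(batch_size=B)
logits = model(idx, input_pos=torch.arange(T))             # prefill
next_tok = logits[:, -1:].argmax(dim=-1)                   # (B, 1)
logits = model(next_tok, input_pos=torch.tensor([T]))      # single-token decode
```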

I am interested in implementing KV cache strategies such as H2O. During inference, we really only have prefill followed by single-token generation. Inference always works like this:

  • Prefill with sequence length T (minimum of prompt size and max cache size)
  • Generate token T
  • Generate token T+1
  • ...

Most KV cache strategies only support this protocol.
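Written against the current call signature, this protocol is just the following loop (a sketch: the model is assumed to have its KV cache set up, sampling is greedy, and stopping criteria are omitted):

```python
import torch

def generate(model, prompt: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    T = prompt.shape[-1]                                  # assume T <= max cache size
    tokens = prompt.unsqueeze(0)                          # (1, T)
    logits = model(tokens, input_pos=torch.arange(T))     # prefill positions 0..T-1
    for t in range(T, T + max_new_tokens):
        next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)    # token at position t
        tokens = torch.cat([tokens, next_tok], dim=-1)
        logits = model(next_tok, input_pos=torch.tensor([t]))    # decode position t
    return tokens
```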

My proposal would be to refactor GPT.forward to support two cases only:

  • Forward pass for training: with an additional flag, this could also be used for prefill, in that it would initialize the KV cache with the K and V vectors already computed as part of the forward pass (they are available anyway, since scaled_dot_product_attention is called).
  • Generate a single token: idx.shape[-1] == 1. input_pos is not really needed here; at most something like input_pos_maxp1 would be. The KV cache itself tracks the position of the next token and would complain if asked to do anything else (see the sketch below).

This supports everything you have right now, plus it supports advanced KV caches like H2O.
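To make the split concrete, here is a minimal sketch of the cache-side protocol it implies; DenseKVCache, prefill, and append are illustrative names I am using here, not existing litgpt API:

```python
import torch

class DenseKVCache:
    """Default dense cache: stores everything, allocated up front."""

    def __init__(self, batch_size: int, n_head: int, max_len: int, head_dim: int,
                 device=None, dtype=None):
        shape = (batch_size, n_head, max_len, head_dim)
        self.k = torch.zeros(shape, device=device, dtype=dtype)
        self.v = torch.zeros(shape, device=device, dtype=dtype)
        self.next_pos = 0   # position of the next token, tracked by the cache itself

    def prefill(self, k: torch.Tensor, v: torch.Tensor) -> None:
        # Initialize the cache with the K/V vectors from a training-style forward.
        T = k.shape[2]
        self.k[:, :, :T] = k
        self.v[:, :, :T] = v
        self.next_pos = T

    def append(self, k: torch.Tensor, v: torch.Tensor):
        # Single-token generation: k and v must have shape (B, n_head, 1, head_dim).
        assert k.shape[2] == 1, "after prefill, only one token at a time is supported"
        pos = self.next_pos
        self.k[:, :, pos : pos + 1] = k
        self.v[:, :, pos : pos + 1] = v
        self.next_pos += 1
        # A static-shape variant would return the full buffers plus a mask instead.
        return self.k[:, :, : self.next_pos], self.v[:, :, : self.next_pos]
```

An H2O-style cache would implement the same prefill/append interface, but decide internally which slots to keep.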

I am happy to do this in a branch in my fork and show you what it would look like.

@mseeger added the enhancement (New feature or request) label on Jan 6, 2025

mseeger commented Jan 6, 2025

I'd also implement batched generation based on this, which seems incomplete at the moment.


t-vi commented Jan 7, 2025

Hi @mseeger ,

great project and thanks for looking into it!

A few things, mainly from a batching perspective:

  • Currently, we can batch (with padding) when we prefill sequences of different lengths, or have a prefill/run combination. Ideally, we would not want to lose that ability.
  • Similarly, we may have external reasons to pad things and/or see the unused bits of the KV cache, so I would not slice the KV cache more than needed. Dynamic shapes can still be tricky with DL compilers, so it would make things easier if this padding still worked.
  • It may be interesting to keep track of the maximum input_pos somewhere; maybe @ali-alshaar7 has thoughts on that (I would have to look).

I would be very grateful if you could keep @ali-alshaar7 in the loop.

Again it's awesome to have you interested in KVCache!


mseeger commented Jan 8, 2025

Do you mean padding prompts of different lengths?

My first attempt would be to prefill up to the minimum length over all prompts, and then go token by token, sampling once a prompt has been fully consumed and taking the next token from the prompt otherwise (see the sketch below). At a later stage, one could prefill further (padding the shorter prompts) and then remove the KV vectors corresponding to the padding again, but that is trickier to implement.
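A rough sketch of that first attempt, assuming a litgpt-style model(idx, input_pos=...) call with a KV cache already set up for the batch size (batched_generate and its details are hypothetical; sampling is greedy and stopping criteria are omitted):

```python
import torch

def batched_generate(model, prompts: list[torch.Tensor], max_new_tokens: int):
    lengths = torch.tensor([p.shape[-1] for p in prompts])
    t0 = int(lengths.min())                         # prefill only up to the shortest prompt
    idx = torch.stack([p[:t0] for p in prompts])    # (B, t0), no padding needed
    logits = model(idx, input_pos=torch.arange(t0))

    out = [p.tolist() for p in prompts]
    for t in range(t0, int(lengths.max()) + max_new_tokens):
        sampled = logits[:, -1].argmax(dim=-1)      # greedy "sampling", for simplicity
        next_tok = torch.empty(len(prompts), dtype=torch.long)
        for b, prompt in enumerate(prompts):
            if t < prompt.shape[-1]:
                next_tok[b] = prompt[t]             # prompt not yet consumed: feed its token
            else:
                next_tok[b] = sampled[b]            # prompt consumed: use the sampled token
                out[b].append(int(sampled[b]))
        logits = model(next_tok.unsqueeze(-1), input_pos=torch.tensor([t]))
    return out
```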


mseeger commented Jan 8, 2025

Regarding your comment on compiling a graph for inference: I think this would still work even with dynamic KV caching like H2O, because it only uses operators like argmin, topk, and scatter. The sizes and shapes of all arrays are determined up front and are the same for every call. The only exception I can see is stopping when encountering the EOS token, but that is the same without KV caching.
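For illustration, here is a hypothetical sketch of a single H2O-style eviction step (simplified: it ignores the recency window and only shows the heavy-hitter part). Every tensor involved has a fixed shape, so each decode step issues the same operations:

```python
import torch

def evict_and_insert(k_cache, v_cache, acc_scores, new_k, new_v, new_score):
    # k_cache, v_cache: (B, H, C, D) with a fixed cache size C
    # acc_scores:       (B, H, C)    accumulated attention mass per cached token
    # new_k, new_v:     (B, H, 1, D) K/V of the newly generated token
    # new_score:        (B, H, 1)
    slot = acc_scores.argmin(dim=-1)                                     # (B, H): weakest token
    idx = slot[:, :, None, None].expand(-1, -1, 1, k_cache.shape[-1])   # (B, H, 1, D)
    k_cache.scatter_(2, idx, new_k)          # overwrite the evicted slot in place
    v_cache.scatter_(2, idx, new_v)
    acc_scores.scatter_(2, slot.unsqueeze(-1), new_score)
    return k_cache, v_cache, acc_scores
```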

And the default KV cache, if you don't specify any, will be the dense one which stores everything and allocates full memory up front. So the default behavior should not change.


mseeger commented Jan 10, 2025

I am not familiar with using compiled (static) graphs for inference, but I suppose this would mean that you need to generate tokens up to the maximum length, even if all sequences in the batch have emitted EOS, right? That sounds pretty inefficient to me. One could do some tricks, of course, such as chunking max_len into pieces, starting with the shortest one, and moving up until all sequences have emitted EOS. But that does not sound simple at all.


t-vi commented Jan 10, 2025

"you need to generate tokens up to the maximum length"

Well, the rule is basically that the launch configuration of, and parameters to, the GPU kernel calls can't change.
So, for example, writing a kernel that skips padding markers works, as does one that reads the max length from a tensor and then only generates up to that. Alignment can also be very important, so funny sequence lengths might be slower than padding. (There is that famous tweet by A. Karpathy about padding the vocabulary size giving him a 20% speedup or so.)
So, basically, I just wanted to mention that there is some nuance to "a smaller tensor is better".
Also, in contrast to the input itself, having a KVCache buffer larger than needed does not affect performance if the attention implementation does not look at masked tokens much: you'll just copy the update into a larger buffer.
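A minimal sketch of that last point, in plain PyTorch: the cache is over-allocated to max_len slots, each decode step copies its single-token update in place, and a boolean mask hides the unused tail, so no shape changes between steps.

```python
import torch

B, H, max_len, D = 1, 8, 4096, 64
k_buf = torch.zeros(B, H, max_len, D)
valid = torch.zeros(B, 1, 1, max_len, dtype=torch.bool)   # mask over cache slots

def write(pos: int, k_new: torch.Tensor) -> None:
    # k_new: (B, H, 1, D) -- copy the single-token update into the big buffer
    k_buf.index_copy_(2, torch.tensor([pos]), k_new)
    valid[..., pos] = True   # unused slots stay masked and are never attended to
```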
