forked from jax-ml/jax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Pallas][Mosaic GPU] Add support for compressing squeezed dims in asy…
…nc_copy + grid fixes This change removes the need to flatten the batch dimension into sequence dimensions in the flash attention kernel. The critical thing here is the observation that we can in fact collapse all squeezed dimension into a single one in the TMA descriptor, letting us reduce its rank when necessary. Doing this also uncovered some issues with how we were handling the grid in Pallas:MGPU lowering, which I've fixed. PiperOrigin-RevId: 701035277
- Loading branch information
1 parent
d5bfafb
commit b801539
Showing
4 changed files
with
110 additions
and
44 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters