carlosgmartin changed the title from "Add an operation for bounded-length for loops" to "Add a primitive operation for bounded-length for loops" on Dec 9, 2024.
Request description
Currently, StableHLO has a `while` op for unbounded while loops. It is also used for bounded `for` loops, e.g. by JAX. Unfortunately, each iteration of a `while` loop requires a kernel launch on GPU, which slows down programs significantly. For context, see:
This can be avoided by unrolling and inlining the entire loop. The downside of that approach is that it causes extremely long compilation times, as explained in the link above. (If we just inline the entire loop, the downstream compiler apparently can't easily "see" that it is the same body being executed over and over again.)
To get the best of both worlds (fast execution and fast compilation), a potential solution would be to create a new primitive op for bounded loops (with a known, fixed number of iterations) that does not require multiple kernel launches or returning control to the host. Call it `for`. The body could then be optimized while still being treated as an atomic unit for the purposes of downstream compilation, avoiding extremely long compilation times for long sequences.

By exposing the bounded-loop structure directly to the compiler, this could also enable further optimizations downstream.
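To make the proposed semantics concrete, here is a minimal Python reference-semantics sketch contrasting today's `while` with the proposed `for`. The names `reference_while` and `bounded_for` are hypothetical (no such functions exist in StableHLO or JAX); they run eagerly in Python and only model what each op would compute. The key difference is that `bounded_for` takes its trip count as a static integer instead of a condition re-evaluated every iteration, so its result should match `jax.lax.fori_loop` with static bounds.

```python
from typing import Callable, TypeVar

import jax.numpy as jnp
from jax import lax

Carry = TypeVar("Carry")


def reference_while(cond: Callable[[Carry], bool],
                    body: Callable[[Carry], Carry],
                    init: Carry) -> Carry:
    # Models today's `while`: the condition is evaluated before every
    # iteration, so the trip count is only known at run time.
    carry = init
    while cond(carry):
        carry = body(carry)
    return carry


def bounded_for(num_iters: int,
                body: Callable[[int, Carry], Carry],
                init: Carry) -> Carry:
    # HYPOTHETICAL model of the proposed `for` op: the trip count is a
    # static, non-negative integer and no condition is evaluated between
    # iterations.
    if not isinstance(num_iters, int) or num_iters < 0:
        raise ValueError("num_iters must be a static non-negative int")
    carry = init
    for i in range(num_iters):
        carry = body(i, carry)
    return carry


def step(i, x):
    # Arbitrary example body shared by all three loop forms.
    return jnp.sin(x) + i


x0 = jnp.arange(4, dtype=jnp.float32)
via_for = bounded_for(16, step, x0)
via_fori = lax.fori_loop(0, 16, step, x0)

# Expressing the same bounded loop with while-style semantics needs an
# explicit counter in the carry and a condition checked on every iteration.
_, via_while = reference_while(lambda c: c[0] < 16,
                               lambda c: (c[0] + 1, step(c[0], c[1])),
                               (0, x0))
```

In this model the counter and the comparison live outside the body, which is the structural information a dedicated `for` op would hand to the compiler directly.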