diff --git a/README.md b/README.md index ca8f31b..c20f43f 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,10 @@ Porting [FlexAttention](https://github.com/pytorch-labs/attention-gym) to pure JAX. -Please install Jax nightly: pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +Please install Jax nightly: +```bash +pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html +``` Example usage: