diff --git a/README.md b/README.md
index 17ad1d0..812cc03 100644
--- a/README.md
+++ b/README.md
@@ -2,7 +2,7 @@
 
 Porting [FlexAttention](https://github.com/pytorch-labs/attention-gym) to pure JAX.
 
-Example usage (For faster performance using Flash Attention, check examples/benchmark.py):
+Example usage (**For faster performance using Flash Attention, check examples/benchmark.py**):
 
 ```python
 import jax
@@ -69,4 +69,4 @@ Float16:
 - FlexAttention: 0.11s
 - FlaxAttention (This repo): 0.13s
 
-We can see that the performance is about 20% slower than the original implementation. There are still some optimizations to be done.
\ No newline at end of file
+We can see that the performance is about 20% slower than the original implementation. There are still some optimizations to be done.
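
For context on the Float16 numbers the README cites (FlexAttention 0.11s vs. FlaxAttention 0.13s, roughly 20% slower): the repo's real benchmark lives in examples/benchmark.py, but a minimal timing sketch in JAX might look like the following. The `attention` workload, the `benchmark` helper, and the tensor shapes here are illustrative assumptions, not the repo's actual API.

```python
import time

import jax
import jax.numpy as jnp


def attention(q, k, v):
    # Plain scaled dot-product attention as a stand-in workload; the
    # repo's examples/benchmark.py presumably times its own kernels.
    scale = q.shape[-1] ** -0.5
    scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) * scale
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bkhd->bqhd", weights, v)


def benchmark(fn, *args, iters=10):
    fn_jit = jax.jit(fn)
    fn_jit(*args).block_until_ready()  # Warm-up run triggers compilation.
    start = time.perf_counter()
    for _ in range(iters):
        out = fn_jit(*args)
    out.block_until_ready()  # JAX dispatch is async; wait for the device.
    return (time.perf_counter() - start) / iters


# Hypothetical shapes: batch=4, seq_len=1024, heads=8, head_dim=64.
keys = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(x, (4, 1024, 8, 64), dtype=jnp.float16) for x in keys)
print(f"mean step time: {benchmark(attention, q, k, v):.4f}s")
```

Because JAX dispatches work asynchronously, calling `block_until_ready()` before stopping the clock (and excluding the compilation warm-up) is what makes timings like the ones above comparable.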