Keras 3.2.0
What changed
- Introduce a QLoRA-like technique for LoRA fine-tuning of `Dense` and `EinsumDense` layers (and thereby any LLM) in int8 precision.
- Extend `keras.ops.custom_gradient` support to PyTorch.
- Add `keras.layers.JaxLayer` and `keras.layers.FlaxLayer` to wrap JAX/Flax modules as Keras layers.
- Allow `save_model` & `load_model` to accept a file-like object.
- Add quantization support to the `Embedding` layer.
- Make it possible to update metrics inside a custom `compute_loss` method with all backends.
- Make it possible to access `self.losses` inside a custom `compute_loss` method with the JAX backend.
- Add `keras.losses.Dice` loss.
- Add `keras.ops.correlate`.
- Make it possible to use cuDNN LSTM & GRU with a mask with the TensorFlow backend.
- Better JAX support in `model.export()`: add support for aliases, finer control over `jax2tf` options, and dynamic batch shapes.
- Bug fixes and performance improvements.
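The idea behind the int8 LoRA item can be illustrated with a conceptual sketch in plain Python. This is not the Keras implementation. It assumes weight-only absmax quantization with one scale per output column, keeps the quantized kernel frozen, and adds a trainable full-precision low-rank update `x @ A @ B`. The names `quantize_int8_per_column`, `matmul` and `Int8LoRADense` are hypothetical helpers for illustration only.

```python
import random


def quantize_int8_per_column(w):
    """Absmax int8 quantization of a 2-D kernel, one scale per output column."""
    rows, cols = len(w), len(w[0])
    scales = []
    for j in range(cols):
        max_abs = max(abs(w[i][j]) for i in range(rows))
        # Map the largest magnitude in the column to 127; guard all-zero columns.
        scales.append(max_abs / 127.0 if max_abs > 0 else 1.0)
    q = [
        [max(-127, min(127, round(w[i][j] / scales[j]))) for j in range(cols)]
        for i in range(rows)
    ]
    return q, scales


def matmul(a, b):
    """Plain nested-list matrix product a @ b."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


class Int8LoRADense:
    """Hypothetical dense layer: frozen int8 kernel plus a trainable LoRA update."""

    def __init__(self, kernel, rank, alpha=1.0, seed=0):
        # Frozen base weights, stored only as int8 values plus per-column scales.
        self.q_kernel, self.scales = quantize_int8_per_column(kernel)
        in_dim, out_dim = len(kernel), len(kernel[0])
        rng = random.Random(seed)
        # Trainable low-rank factors: A is small and random, B starts at zero,
        # so the adapter contributes nothing until B is trained.
        self.lora_a = [[rng.gauss(0.0, 0.01) for _ in range(rank)] for _ in range(in_dim)]
        self.lora_b = [[0.0] * out_dim for _ in range(rank)]
        self.lora_scale = alpha / rank

    def __call__(self, x):
        # Multiply by the int8 values, then rescale each output column,
        # since x @ (q * s) == (x @ q) * s when s is a per-column scale.
        base = matmul(x, self.q_kernel)
        base = [[v * self.scales[j] for j, v in enumerate(row)] for row in base]
        # Full-precision low-rank correction x @ A @ B.
        delta = matmul(matmul(x, self.lora_a), self.lora_b)
        return [
            [b + self.lora_scale * d for b, d in zip(b_row, d_row)]
            for b_row, d_row in zip(base, delta)
        ]


layer = Int8LoRADense([[0.5, -1.0], [0.25, 2.0]], rank=2)
print(layer([[1.0, 2.0]]))
```

Because `B` starts at zero, the call above returns values within about 0.01 of the full-precision product `[[1.0, 3.0]]`; only the int8 rounding error remains until the adapter is trained. Only the two small LoRA factors would need gradients, which is what makes this style of fine-tuning cheap.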
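The Dice loss is built on the Dice coefficient, a measure of overlap between a target and a prediction. The plain-Python sketch below computes the common "soft" form, `1 - 2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred))`, over all elements. The `epsilon` guard and the helper names are assumptions for illustration; `keras.losses.Dice` may differ in details such as smoothing and reduction.

```python
def _flatten(values):
    """Yield the scalars of an arbitrarily nested list/tuple as floats."""
    for v in values:
        if isinstance(v, (list, tuple)):
            yield from _flatten(v)
        else:
            yield float(v)


def dice_loss(y_true, y_pred, epsilon=1e-7):
    """Soft Dice loss over all elements: 1 - 2*sum(t*p) / (sum(t) + sum(p))."""
    t = list(_flatten(y_true))
    p = list(_flatten(y_pred))
    if len(t) != len(p):
        raise ValueError("y_true and y_pred must contain the same number of elements")
    intersection = sum(a * b for a, b in zip(t, p))
    # epsilon (an assumed smoothing constant) avoids division by zero on empty masks.
    return 1.0 - (2.0 * intersection) / (sum(t) + sum(p) + epsilon)


# Perfect overlap gives a loss near 0, no overlap gives 1, half overlap roughly 0.5.
print(dice_loss([[1, 0], [1, 0]], [[1, 0], [1, 0]]))
print(dice_loss([1, 0], [0, 1]))
print(dice_loss([1, 1, 0, 0], [1, 0, 1, 0]))
```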
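`keras.ops.correlate` computes the cross-correlation of two 1-D tensors. Assuming it follows the semantics of `numpy.correlate` for real inputs, the pure-Python sketch below shows what the `"valid"` and `"full"` modes compute, namely `c[k] = sum_n a[n + k] * v[n]`. The `correlate` function here is a hypothetical stand-in; it omits NumPy's `"same"` mode and complex conjugation.

```python
def correlate(a, v, mode="valid"):
    """Cross-correlation of two real 1-D sequences, numpy.correlate-style."""
    if not a or not v:
        raise ValueError("inputs must be non-empty")
    if mode == "full":
        # Zero-pad so every partial overlap is included:
        # output length becomes len(a) + len(v) - 1.
        pad = [0.0] * (len(v) - 1)
        a = pad + list(a) + pad
    elif mode != "valid":
        raise ValueError("this sketch supports only 'valid' and 'full' modes")
    if len(a) < len(v):
        raise ValueError("this sketch expects len(a) >= len(v) in 'valid' mode")
    # Slide v across a and take the dot product at each full-overlap position.
    return [
        sum(a[k + n] * v[n] for n in range(len(v)))
        for k in range(len(a) - len(v) + 1)
    ]


print(correlate([1, 2, 3], [0, 1, 0.5]))          # [3.5]
print(correlate([1, 2, 3], [0, 1, 0.5], "full"))  # [0.5, 2.0, 3.5, 3.0, 0.0]
```

These two results match the examples in NumPy's own `numpy.correlate` documentation, which is a quick way to sanity-check the sketch against the assumed semantics.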
New Contributors
- @abhaskumarsinha made their first contribution in #19302
- @qaqland made their first contribution in #19378
- @tvogel made their first contribution in #19310
- @lpizzinidev made their first contribution in #19409
- @Murhaf made their first contribution in #19444
Full Changelog: v3.1.1...v3.2.0