
Zodiax



Zodiax is a lightweight extension to the object-oriented Jax framework Equinox. Equinox allows for differentiable classes that are recognised as valid Jax types, and Zodiax adds simple methods for interfacing with these classes! Zodiax was originally built during the development of dLux and was designed to make working with large nested class structures simple and flexible.

Zodiax is directly integrated with both Jax and Equinox, gaining all of their core features.

Documentation: louisdesdoigts.github.io/zodiax/

Contributors: Louis Desdoigts

Requires: Python 3.9+, Jax 0.4.25+

Installation: pip install zodiax

Docs installation: pip install "zodiax[docs]"

Test installation: pip install "zodiax[tests]"


Quickstart

Create a regular class that inherits from zodiax.Base:

import jax
import zodiax as zdx
import jax.numpy as np

class Linear(zdx.Base):
    # Parameters are declared as annotated class fields
    m : jax.Array
    b : jax.Array

    def __init__(self, m, b):
        self.m = m
        self.b = b

    def model(self, x):
        # A straight line: y = m * x + b
        return self.m * x + self.b

linear = Linear(1., 1.)

It's that simple! The Linear class is now a fully differentiable object that gives us all the benefits of Jax with an object-oriented interface! Let's see how we can jit-compile this class and take gradients with respect to its parameters.

@jax.jit
@jax.grad
def loss_fn(model, xs, ys):
    return np.square(model.model(xs) - ys).sum()

xs = np.arange(5)
ys = 2*np.arange(5)
grads = loss_fn(linear, xs, ys)
print(grads)
print(grads.m, grads.b)
> Linear(m=f32[], b=f32[])
> -40.0 -10.0

The grads object is an instance of the Linear class, holding the gradient of the loss function with respect to each parameter!