Skip to content

Commit

Permalink
python311Packages.jax: 0.4.28 -> 0.4.35
Browse files Browse the repository at this point in the history
  • Loading branch information
GaetanLepage committed Oct 23, 2024
1 parent ab2ccef commit 1ad0edc
Showing 1 changed file with 27 additions and 20 deletions.
47 changes: 27 additions & 20 deletions pkgs/development/python-modules/jax/default.nix
Original file line number Diff line number Diff line change
@@ -1,46 +1,52 @@
{
lib,
stdenv,
blas,
lapack,
buildPythonPackage,
fetchFromGitHub,
callPackage,
pythonOlder,

# build-system
setuptools,
importlib-metadata,
fetchFromGitHub,

# dependencies
jaxlib,
jaxlib-bin,
jaxlib-build,
hypothesis,
lapack,
matplotlib,
ml-dtypes,
numpy,
opt-einsum,
scipy,
importlib-metadata,

# nativeCheckInputs
hypothesis,
matplotlib,
pytestCheckHook,
pytest-xdist,
pythonOlder,
scipy,
stdenv,

# passthru
jaxlib-build,
jaxlib-bin,
}:

let
usingMKL = blas.implementation == "mkl" || lapack.implementation == "mkl";
in
buildPythonPackage rec {
pname = "jax";
version = "0.4.28";
version = "0.4.32";
pyproject = true;

disabled = pythonOlder "3.9";

src = fetchFromGitHub {
owner = "google";
repo = "jax";
# google/jax contains tags for jax and jaxlib. Only use jax tags!
rev = "refs/tags/jax-v${version}";
hash = "sha256-qSHPwi3is6Ts7pz5s4KzQHBMbcjGp+vAOsejW3o36Ek=";
hash = "sha256-eg+uP0ZWHG6R+5UGeAcKyg+v150ANQvD31jYtPkcYc8=";
};

nativeBuildInputs = [ setuptools ];
build-system = [ setuptools ];

# The version is automatically set to ".dev" if this variable is not set.
# https://github.com/google/jax/commit/e01f2617b85c5bdffc5ffb60b3d8d8ca9519a1f3
Expand All @@ -49,7 +55,8 @@ buildPythonPackage rec {
# jaxlib is _not_ included in propagatedBuildInputs because there are
# different versions of jaxlib depending on the desired target hardware. The
# JAX project ships separate wheels for CPU, GPU, and TPU.
propagatedBuildInputs = [
dependencies = [
jaxlib
ml-dtypes
numpy
opt-einsum
Expand All @@ -58,7 +65,6 @@ buildPythonPackage rec {

nativeCheckInputs = [
hypothesis
jaxlib
matplotlib
pytestCheckHook
pytest-xdist
Expand Down Expand Up @@ -158,14 +164,15 @@ buildPythonPackage rec {
# updater fails to pick the correct branch
passthru.skipBulkUpdate = true;

meta = with lib; {
meta = {
description = "Source-built JAX frontend: differentiate, compile, and transform Numpy code";
longDescription = ''
This is the JAX frontend package, it's meant to be used together with one of the jaxlib implementations,
e.g. `python3Packages.jaxlib`, `python3Packages.jaxlib-bin`, or `python3Packages.jaxlibWithCuda`.
'';
homepage = "https://github.com/google/jax";
license = licenses.asl20;
maintainers = with maintainers; [ samuela ];
changelog = "https://github.com/google/jax/releases/tag/jax-v${version}";
license = lib.licenses.asl20;
maintainers = with lib.maintainers; [ samuela ];
};
}

0 comments on commit 1ad0edc

Please sign in to comment.