From cf938d2c6dbe63a35ee531aaa0d783ef509ab82e Mon Sep 17 00:00:00 2001 From: Achille Nazaret Date: Mon, 4 Nov 2024 15:38:38 -0500 Subject: [PATCH 1/2] import Sample in main __init__ (to make it discoverable for the doc) --- src/treeffuser/__init__.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/treeffuser/__init__.py b/src/treeffuser/__init__.py index 17361ec2..6c548670 100644 --- a/src/treeffuser/__init__.py +++ b/src/treeffuser/__init__.py @@ -1,4 +1,5 @@ __version__ = "0.1.3" +from treeffuser.samples import Samples from treeffuser.treeffuser import Treeffuser -__all__ = ["Treeffuser"] +__all__ = ["Treeffuser", "Samples"] From f9902146997e0182ed33990c98afff4d99f83675 Mon Sep 17 00:00:00 2001 From: Achille Nazaret Date: Mon, 4 Nov 2024 15:40:18 -0500 Subject: [PATCH 2/2] use direct import for simpler code --- src/treeffuser/samples.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/treeffuser/samples.py b/src/treeffuser/samples.py index 216f0c48..73f88c6d 100644 --- a/src/treeffuser/samples.py +++ b/src/treeffuser/samples.py @@ -4,8 +4,8 @@ from typing import Union import numpy as np -import sklearn from jaxtyping import Float +from sklearn.neighbors import KernelDensity from tqdm import tqdm @@ -118,7 +118,7 @@ def sample_kde( self, bandwidth: Union[float, Literal["scott", "silverman"]] = 1.0, verbose: bool = False, - ) -> List[sklearn.neighbors.KernelDensity]: + ) -> List[KernelDensity]: """ Compute the Kernel Density Estimate (KDE) for each `x`. Estimate: `KDE[Y | X = x]` for each `x` using Gaussian kernels from `sklearn.neighbors`. @@ -135,8 +135,8 @@ def sample_kde( Returns ------- - kdes : list of sklearn.neighbors.KernelDensity - A list of `sklearn.neighbors.KernelDensity` objects, one for each `x`. + kdes : list of KernelDensity + A list of `KernelDensity` objects, one for each `x`. """ kdes = [] for i in tqdm( @@ -148,9 +148,7 @@ def sample_kde( y_i = self._samples[:, i, None] else: y_i = self._samples[:, i, :] - kde = sklearn.neighbors.KernelDensity( - bandwidth=bandwidth, algorithm="auto", kernel="gaussian" - ) + kde = KernelDensity(bandwidth=bandwidth, algorithm="auto", kernel="gaussian") kde.fit(y_i) kdes.append(kde)