From 3826fa5f4dbb9b76408fd7736025cd687ea92020 Mon Sep 17 00:00:00 2001
From: Lukas Heumos
Date: Fri, 10 Jan 2025 21:40:04 +0100
Subject: [PATCH] Optimize MeanVarDistributionDistance (#697)

* Fix probability data type

Signed-off-by: Lukas Heumos

* Optimize mean_var distance

Signed-off-by: Lukas Heumos

---------

Signed-off-by: Lukas Heumos
---
 pertpy/tools/_distances/_distances.py | 104 ++++++++++++++------------
 1 file changed, 56 insertions(+), 48 deletions(-)

diff --git a/pertpy/tools/_distances/_distances.py b/pertpy/tools/_distances/_distances.py
index ec9001d4..b0b56a61 100644
--- a/pertpy/tools/_distances/_distances.py
+++ b/pertpy/tools/_distances/_distances.py
@@ -1117,67 +1117,75 @@ def __init__(self) -> None:
         super().__init__()
         self.accepts_precomputed = False
 
+    @staticmethod
+    def _mean_var(x, log: bool = False):
+        mean = np.mean(x, axis=0)
+        var = np.var(x, axis=0)
+        positive = mean > 0
+        mean = mean[positive]
+        var = var[positive]
+        if log:
+            mean = np.log(mean)
+            var = np.log(var)
+        return mean, var
+
+    @staticmethod
+    def _prep_kde_data(x, y):
+        return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
+
+    @staticmethod
+    def _grid_points(d, n_points=100):
+        # Make grid, add 1 bin on lower/upper end to get final n_points
+        d_min = d.min()
+        d_max = d.max()
+        # Compute bin size
+        d_bin = (d_max - d_min) / (n_points - 2)
+        d_min = d_min - d_bin
+        d_max = d_max + d_bin
+        return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
+
+    @staticmethod
+    def _kde_eval_both(x_kde, y_kde, grid):
+        n_points = len(grid)
+        chunk_size = 10000
+
+        result_x = np.zeros(n_points)
+        result_y = np.zeros(n_points)
+
+        # Process same chunks for both KDEs
+        for start in range(0, n_points, chunk_size):
+            end = min(start + chunk_size, n_points)
+            chunk = grid[start:end]
+            result_x[start:end] = x_kde.score_samples(chunk)
+            result_y[start:end] = y_kde.score_samples(chunk)
+
+        return result_x, result_y
+
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
         """Difference of mean-var distributions in 2 matrices.
-
         Args:
             X: Normalized and log transformed cells x genes count matrix.
             Y: Normalized and log transformed cells x genes count matrix.
         """
+        mean_x, var_x = self._mean_var(X, log=True)
+        mean_y, var_y = self._mean_var(Y, log=True)
 
-        def _mean_var(x, log: bool = False):
-            mean = np.mean(x, axis=0)
-            var = np.var(x, axis=0)
-            positive = mean > 0
-            mean = mean[positive]
-            var = var[positive]
-            if log:
-                mean = np.log(mean)
-                var = np.log(var)
-            return mean, var
-
-        def _prep_kde_data(x, y):
-            return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
-
-        def _grid_points(d, n_points=100):
-            # Make grid, add 1 bin on lower/upper end to get final n_points
-            d_min = d.min()
-            d_max = d.max()
-            # Compute bin size
-            d_bin = (d_max - d_min) / (n_points - 2)
-            d_min = d_min - d_bin
-            d_max = d_max + d_bin
-            return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
-
-        def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
-            # the thread_count is determined using the factor 0.875 as recommended here:
-            # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
-            with multiprocessing.Pool(thread_count) as p:
-                return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
-
-        def _kde_eval(d, grid):
-            # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
-            # can not be compared well on regions further away from the data as they are -inf
-            kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
-            return _parallel_score_samples(kde, grid)
-
-        mean_x, var_x = _mean_var(X, log=True)
-        mean_y, var_y = _mean_var(Y, log=True)
-
-        x = _prep_kde_data(mean_x, var_x)
-        y = _prep_kde_data(mean_y, var_y)
+        x = self._prep_kde_data(mean_x, var_x)
+        y = self._prep_kde_data(mean_y, var_y)
 
         # Gridpoints to eval KDE on
-        mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
-        var_grid = _grid_points(np.concatenate([var_x, var_y]))
+        mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
+        var_grid = self._grid_points(np.concatenate([var_x, var_y]))
         grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)
 
-        kde_x = _kde_eval(x, grid)
-        kde_y = _kde_eval(y, grid)
+        # Fit both KDEs first
+        x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
+        y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)
 
-        kde_diff = ((kde_x - kde_y) ** 2).mean()
+        # Evaluate both KDEs on same grid chunks
+        kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)
 
-        return kde_diff
+        return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()
 
     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")