Optimize MeanVarDistributionDistance (#697)
* Fix probability data type

Signed-off-by: Lukas Heumos <[email protected]>

* Optimize mean_var distance

Signed-off-by: Lukas Heumos <[email protected]>

---------

Signed-off-by: Lukas Heumos <[email protected]>
Zethson authored Jan 10, 2025
1 parent 1e80db7 commit 3826fa5
Showing 1 changed file with 56 additions and 48 deletions.
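
Two changes are worth unpacking before the diff. First, the "probability data type" fix: sklearn's KernelDensity.score_samples returns log-densities, so the old implementation compared log-probabilities; the new return value exponentiates both before taking the squared difference, comparing actual densities. Second, the optimization: rather than fitting a KDE inside a helper and scoring it across a multiprocessing pool, both KDEs are now fit once up front and scored over the same grid in fixed-size chunks. A minimal sketch of the density fix, using hypothetical toy data (a usage sketch for the full metric follows the diff):

import numpy as np
from sklearn.neighbors import KernelDensity

# Hypothetical 2D points standing in for (log mean, log variance) pairs.
rng = np.random.default_rng(0)
points = rng.normal(size=(500, 2))

kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(points)

log_p = kde.score_samples(points[:3])  # log-densities, not probabilities
p = np.exp(log_p)                      # exponentiate to recover densities

# Squared differences of log-densities blow up in low-density regions, where
# log p drifts toward -inf; differencing np.exp(log_p) values avoids that.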
pertpy/tools/_distances/_distances.py (56 additions, 48 deletions)
@@ -1117,67 +1117,75 @@ def __init__(self) -> None:
         super().__init__()
         self.accepts_precomputed = False

+    @staticmethod
+    def _mean_var(x, log: bool = False):
+        mean = np.mean(x, axis=0)
+        var = np.var(x, axis=0)
+        positive = mean > 0
+        mean = mean[positive]
+        var = var[positive]
+        if log:
+            mean = np.log(mean)
+            var = np.log(var)
+        return mean, var
+
+    @staticmethod
+    def _prep_kde_data(x, y):
+        return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
+
+    @staticmethod
+    def _grid_points(d, n_points=100):
+        # Make grid, add 1 bin on lower/upper end to get final n_points
+        d_min = d.min()
+        d_max = d.max()
+        # Compute bin size
+        d_bin = (d_max - d_min) / (n_points - 2)
+        d_min = d_min - d_bin
+        d_max = d_max + d_bin
+        return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
+
+    @staticmethod
+    def _kde_eval_both(x_kde, y_kde, grid):
+        n_points = len(grid)
+        chunk_size = 10000
+
+        result_x = np.zeros(n_points)
+        result_y = np.zeros(n_points)
+
+        # Process same chunks for both KDEs
+        for start in range(0, n_points, chunk_size):
+            end = min(start + chunk_size, n_points)
+            chunk = grid[start:end]
+            result_x[start:end] = x_kde.score_samples(chunk)
+            result_y[start:end] = y_kde.score_samples(chunk)
+
+        return result_x, result_y
+
     def __call__(self, X: np.ndarray, Y: np.ndarray, **kwargs) -> float:
         """Difference of mean-var distributions in 2 matrices.

         Args:
             X: Normalized and log transformed cells x genes count matrix.
             Y: Normalized and log transformed cells x genes count matrix.
         """
+        mean_x, var_x = self._mean_var(X, log=True)
+        mean_y, var_y = self._mean_var(Y, log=True)
-        def _mean_var(x, log: bool = False):
-            mean = np.mean(x, axis=0)
-            var = np.var(x, axis=0)
-            positive = mean > 0
-            mean = mean[positive]
-            var = var[positive]
-            if log:
-                mean = np.log(mean)
-                var = np.log(var)
-            return mean, var
-
-        def _prep_kde_data(x, y):
-            return np.concatenate([x.reshape(-1, 1), y.reshape(-1, 1)], axis=1)
-
-        def _grid_points(d, n_points=100):
-            # Make grid, add 1 bin on lower/upper end to get final n_points
-            d_min = d.min()
-            d_max = d.max()
-            # Compute bin size
-            d_bin = (d_max - d_min) / (n_points - 2)
-            d_min = d_min - d_bin
-            d_max = d_max + d_bin
-            return np.arange(start=d_min + 0.5 * d_bin, stop=d_max, step=d_bin)
-
-        def _parallel_score_samples(kde, samples, thread_count=int(0.875 * multiprocessing.cpu_count())):
-            # the thread_count is determined using the factor 0.875 as recommended here:
-            # https://stackoverflow.com/questions/32625094/scipy-parallel-computing-in-ipython-notebook
-            with multiprocessing.Pool(thread_count) as p:
-                return np.concatenate(p.map(kde.score_samples, np.array_split(samples, thread_count)))
-
-        def _kde_eval(d, grid):
-            # Kernel choice: Gaussian is too smoothing and cosine or other kernels that do not stretch out
-            # can not be compared well on regions further away from the data as they are -inf
-            kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(d)
-            return _parallel_score_samples(kde, grid)
-
-        mean_x, var_x = _mean_var(X, log=True)
-        mean_y, var_y = _mean_var(Y, log=True)

-        x = _prep_kde_data(mean_x, var_x)
-        y = _prep_kde_data(mean_y, var_y)
+        x = self._prep_kde_data(mean_x, var_x)
+        y = self._prep_kde_data(mean_y, var_y)

         # Gridpoints to eval KDE on
-        mean_grid = _grid_points(np.concatenate([mean_x, mean_y]))
-        var_grid = _grid_points(np.concatenate([var_x, var_y]))
+        mean_grid = self._grid_points(np.concatenate([mean_x, mean_y]))
+        var_grid = self._grid_points(np.concatenate([var_x, var_y]))
         grid = np.array(np.meshgrid(mean_grid, var_grid)).T.reshape(-1, 2)

-        kde_x = _kde_eval(x, grid)
-        kde_y = _kde_eval(y, grid)
+        # Fit both KDEs first
+        x_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(x)
+        y_kde = KernelDensity(bandwidth="silverman", kernel="exponential").fit(y)

-        kde_diff = ((kde_x - kde_y) ** 2).mean()
+        # Evaluate both KDEs on same grid chunks
+        kde_x, kde_y = self._kde_eval_both(x_kde, y_kde, grid)

-        return kde_diff
+        return ((np.exp(kde_x) - np.exp(kde_y)) ** 2).mean()

     def from_precomputed(self, P: np.ndarray, idx: np.ndarray, **kwargs) -> float:
         raise NotImplementedError("MeanVarDistributionDistance cannot be called on a pairwise distance matrix.")
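
For orientation, a hypothetical usage sketch. It assumes pertpy registers this class under the metric name "mean_var_distribution" and that Distance instances are callable on two cells x genes arrays; both are assumptions about the surrounding API, not shown in this diff:

import numpy as np
import pertpy as pt

# Stand-ins for normalized, log-transformed cells x genes count matrices.
rng = np.random.default_rng(0)
X = np.abs(rng.normal(size=(300, 100)))
Y = np.abs(rng.normal(size=(250, 100)))

distance = pt.tools.Distance(metric="mean_var_distribution")  # assumed metric name
result = distance(X, Y)  # mean squared difference between the two mean-var KDEs
print(result)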