
How about ChamferDistance instead of cdist to calculate KNN? #64

Open
Kitsunetic opened this issue Jul 26, 2023 · 2 comments

Comments


Kitsunetic commented Jul 26, 2023

Hi, I propose using a KNN function of the kind common in the point cloud field instead of torch.cdist.

VectorQuantization finds the nearest vector in the codebook (B, M, D) for a given input (B, N, D).

This repository uses torch.cdist to find the nearest codebook entry (e.g. (B, N, D), (B, M, D) -> indices of shape (B, N)) here, which first builds the full similarity matrix (B, N, M).

This method requires keeping the similarity matrix (B, N, M) in memory, which becomes inefficient as N or M grows. In contrast, a ChamferDistance-style kernel never keeps the full (B, N, M) matrix in memory; it uses a CUDA reduction instead. For example, given two input sequences of vectors (B, N, D) and (B, M, D), it can directly return (B, N, D) and (B, N; int64), where the first output (B, N, D) has the same shape as the input but contains codebook values, and (B, N; int64) holds the indices of the nearest codebook entries. This is similar to how memory-efficient attention reduces the memory required by Transformer computation.
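To illustrate the memory point, here is a minimal pure-PyTorch sketch (function names are hypothetical, not from this repository). A fused CUDA kernel like the one described above avoids the (B, N, M) matrix entirely; the chunked version below only bounds it at (B, chunk, M), but shows that the argmin result is identical to the full-matrix baseline:

```python
import torch

def nearest_codebook_full(x, codebook):
    """Baseline: materializes the full (B, N, M) distance matrix."""
    dists = torch.cdist(x, codebook)            # (B, N, M)
    return dists.argmin(dim=-1)                 # (B, N)

def nearest_codebook_chunked(x, codebook, chunk=1024):
    """Processes queries in chunks so peak memory is (B, chunk, M)."""
    indices = []
    for x_chunk in x.split(chunk, dim=1):       # (B, <=chunk, D)
        dists = torch.cdist(x_chunk, codebook)  # (B, <=chunk, M)
        indices.append(dists.argmin(dim=-1))
    return torch.cat(indices, dim=1)            # (B, N)

def quantize(x, codebook, chunk=1024):
    """Returns quantized vectors (B, N, D) and indices (B, N)."""
    idx = nearest_codebook_chunked(x, codebook, chunk)
    quantized = torch.gather(
        codebook, 1, idx.unsqueeze(-1).expand(-1, -1, codebook.shape[-1])
    )
    return quantized, idx
```

Both functions return the same indices; the difference is only in peak memory.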

Such a KNN implementation is available in off-the-shelf libraries such as pytorch3d.ops.knn_points(); it is differentiable, DDP-safe, and works on both CPU and GPU.


lucidrains commented Jul 30, 2023

@Kitsunetic hey! thanks for proposing this

i was actually looking into kmeans++

so basically this is identical to kmeans except more memory efficient?

@lucidrains

@Kitsunetic since kmeans is only calculated at the start, i don't think it really affects training all that much?
