
Add softmax_csr implementation #264

Merged

Conversation

@DamianSzwichtenberg (Member) commented Oct 12, 2023

This PR adds a forward and backward implementation of the sparse softmax operation as defined here.

In the pytorch_geometric implementation, we cannot take advantage of model compilation when groups are defined via ptr. The softmax_csr operation introduced here provides a well-performing kernel for that scenario.

Performance boost (achieved on a 28-core, single-socket machine):
~7x for forward pass
~8x for backward pass
Additionally, GAT training time was reduced by ~5%.
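
A minimal usage sketch (assuming the Python binding is exposed as `pyg_lib.ops.softmax_csr(src, ptr, dim)`, as the rest of this thread suggests):

```python
import torch
from pyg_lib.ops import softmax_csr  # needs a pyg-lib build containing this PR

src = torch.rand(4, 4)
ptr = torch.tensor([0, 2, 4])  # two groups along dim=0: rows [0, 2) and [2, 4)

# Softmax is computed independently within each row group, per column:
# out[0:2] and out[2:4] each sum to 1.0 along dim=0 in every column.
out = softmax_csr(src, ptr, dim=0)
```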

@codecov (bot) commented Oct 12, 2023

Codecov Report

Attention: 4 lines in your changes are missing coverage. Please review.

Comparison is base (2b9af1c) 85.65% compared to head (40c8f52) 86.19%.

| Files | Patch % | Lines |
| --- | --- | --- |
| `pyg_lib/csrc/ops/cpu/softmax_kernel.cpp` | 91.11% | 4 Missing ⚠️ |
Additional details and impacted files
```diff
@@            Coverage Diff             @@
##           master     #264      +/-   ##
==========================================
+ Coverage   85.65%   86.19%   +0.54%
==========================================
  Files          32       34       +2
  Lines        1115     1188      +73
==========================================
+ Hits          955     1024      +69
- Misses        160      164       +4
```


@DamianSzwichtenberg (Member, Author) commented:

@pyg-team/intel-team Please take a look.

```python
       [0.0598, 0.2923, 0.1206, 0.0921],
       [0.7792, 0.3502, 0.1638, 0.2145]])
"""
if src.dim() != 2 or not src.is_cpu or ptr is None or dim != 0:
```
A contributor commented:
I wonder why ptr is optional, since omitting it raises an error.

@DamianSzwichtenberg (Member, Author) replied:
I wanted the API to be in its final form; otherwise, each change here would require a corresponding change in pytorch_geometric. I'll add support for index in the near future.

@DamianSzwichtenberg (Member, Author) replied:
@kgajdamo, after rethinking your suggestion, I decided to change the API and create a specialized softmax_csr operation that accepts ptr only. Rationale: torch.compile already gives good results for softmax with groups defined via index, so I don't see a reason to have a specialized kernel for that option.
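
For reference, a minimal sketch of what the index-based variant looks like in pure PyTorch scatter ops (the kind of code torch.compile handles well; `softmax_index` and `num_groups` are illustrative names, not part of this PR):

```python
import torch

def softmax_index(src: torch.Tensor, index: torch.Tensor,
                  num_groups: int) -> torch.Tensor:
    # Per-group maximum, for numerical stability.
    grp_max = src.new_full((num_groups,), float('-inf')).scatter_reduce(
        0, index, src, reduce='amax')
    out = (src - grp_max[index]).exp()
    # Per-group sum of exponentials.
    grp_sum = src.new_zeros(num_groups).scatter_add(0, index, out)
    return out / grp_sum[index]
```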

@DamianSzwichtenberg changed the title from "Add sparse softmax implementation" to "Add softmax_csr implementation" on Nov 3, 2023
@DamianSzwichtenberg (Member, Author) commented:

@kgajdamo @rusty1s Please take a look. I made the softmax implementation a bit more general, so it now covers any src dimensionality as well as any given dim. I also restricted groups to be defined via ptr (rationale above).

@yanbing-j (Contributor) commented:

Hi @DamianSzwichtenberg, this PR looks good to me. The overall structure of the sparse-input softmax kernel is similar to that of the dense-input softmax kernel in PyTorch. With the sparsity, the performance boost comes from parallelism, right?

And will this PR be upstreamed to PyTorch later, since there is no SparseCsr support for softmax in PyTorch yet?

@DamianSzwichtenberg (Member, Author) replied:

> The overall structure of the sparse-input softmax kernel is similar to that of the dense-input softmax kernel in PyTorch. With the sparsity, the performance boost comes from parallelism, right?

These kernels differ quite a bit. In softmax_csr, groups are created from ptr across the given dim, and for each dimension other than dim we also create a separate group. If you take a look at the tests, to achieve the same result with torch.nn.Softmax you need to slice the tensor. As for performance gains, parallelization is done across the groups defined by the user (which also differs from Softmax). Additionally, for the most common scenario (dim=0), we access data in a contiguous manner, even though contiguous elements do not belong to the same group. So performance comes from parallelization and a good memory-access pattern.
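
For example, the grouped result for ptr = [0, 2, 4] with dim=0 can be reproduced with the dense operator by slicing each row group separately (a sketch of the equivalence the tests check, not their literal code):

```python
import torch

src = torch.rand(4, 4)
ptr = torch.tensor([0, 2, 4])

# Apply a dense softmax to each row group [ptr[i], ptr[i + 1]) in turn,
# then stitch the per-group results back together along dim=0.
expected = torch.cat([
    torch.softmax(src[i:j], dim=0)
    for i, j in zip(ptr[:-1].tolist(), ptr[1:].tolist())
], dim=0)
```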

> And will this PR be upstreamed to PyTorch later, since there is no SparseCsr support for softmax in PyTorch yet?

There are no plans to upstream this operation to PyTorch. As noted above, softmax_csr differs from the Softmax operation defined in torch.

@kgajdamo (Contributor) left a comment:

Looks good to me.

@DamianSzwichtenberg merged commit 0e787f1 into pyg-team:master on Nov 17, 2023
10 checks passed
Comment on lines +334 to +336
```python
class Softmax(torch.autograd.Function):
    @staticmethod
    def forward(
```
A member commented:

Can we define the autograd function directly in C++?

@DamianSzwichtenberg (Member, Author) replied:

Should be possible, will check.

@DamianSzwichtenberg (Member, Author) commented Nov 20, 2023:

Change available at #282

DamianSzwichtenberg added a commit to pyg-team/pytorch_geometric that referenced this pull request Nov 21, 2023
This PR uses the optimized `softmax_csr` operation (introduced in [pyg-lib#264](pyg-team/pyg-lib#264)) when the input is a CPU tensor and softmax groups are defined via `ptr`.
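
The dispatch added there boils down to something like the following (a simplified, hypothetical sketch; the actual guard in torch_geometric also checks whether pyg-lib is available):

```python
# Hypothetical sketch of the dispatch (names and guards simplified):
def softmax(src, index=None, ptr=None, dim=0):
    if ptr is not None and src.is_cpu:
        # Optimized pyg-lib kernel for CPU tensors with ptr-defined groups.
        import pyg_lib
        return pyg_lib.ops.softmax_csr(src, ptr, dim)
    ...  # otherwise: the existing scatter-based implementation
```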