Skip to content

Commit

Permalink
Improve softmax performance.
Browse files Browse the repository at this point in the history
  • Loading branch information
doru1004 committed Dec 2, 2024
1 parent 5672206 commit 5515a5a
Showing 1 changed file with 140 additions and 0 deletions.
140 changes: 140 additions & 0 deletions aten/src/ATen/native/cuda/SoftMax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ struct SoftMaxBackwardEpilogue {
const AccumT sum;
};

#ifdef USE_ROCM
template<typename T, typename AccumT, typename OutT>
struct SoftMaxForwardWithMulEpilogue {
__device__ __forceinline__ SoftMaxForwardWithMulEpilogue(AccumT max_input, AccumT sum)
: max_input(max_input)
, sum(sum) {}

__device__ __forceinline__ OutT operator()(T input) const {
#ifdef PYTORCH_USE_EXPF
return static_cast<OutT>(__expf(input - max_input) * sum);
#else
return static_cast<OutT>(std::exp(input - max_input) * sum);
#endif
}

const AccumT max_input;
const AccumT sum;
};
#endif




Expand Down Expand Up @@ -387,6 +407,21 @@ struct SumExpFloat
const AccumT max_k;
};

#ifdef USE_ROCM
template<typename T, typename AccumT>
struct SumExpfFloat
{
__device__ __forceinline__ SumExpfFloat(AccumT v)
: max_k(v) {}

__device__ __forceinline__ AccumT operator()(AccumT sum, T v) const {
return sum + __expf(v - max_k);
}

const AccumT max_k;
};
#endif

template <template<typename> class Reduction, typename AccumT>
__device__ __forceinline__ AccumT
blockReduce(AccumT* smem, AccumT val,
Expand Down Expand Up @@ -449,6 +484,18 @@ T blockReduceWarp(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
return smem_cache[0];
}

template <template<typename> class Reduction, typename T>
__device__ __forceinline__
T blockReduceWarpInverse(T* smem_cache, T value, const Reduction<T>& op, T defaultVal)
{
T result = cuda_utils::BlockReduce<T, Reduction<T>>(value, op, defaultVal, smem_cache);
if (threadIdx.x == 0) {
smem_cache[0] = 1 / result;
}
__syncthreads();
return smem_cache[0];
}

template <template<typename, typename> class Reduction, int ILP, typename T, typename AccumT, typename index_t=int>
__device__ __forceinline__ AccumT
ilpReduce(index_t shift,
Expand Down Expand Up @@ -694,6 +741,73 @@ cunn_SoftMaxForward(outscalar_t *output, const scalar_t *input, int classes)
}
}

#ifdef USE_ROCM
template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
__global__ void
cunn_SoftMaxForwardGmem(outscalar_t *output, const scalar_t *input, index_t classes)
{
// Each thread block processes a sample in the batch
input += static_cast<int64_t>(blockIdx.x) * classes;
output += static_cast<int64_t>(blockIdx.x) * classes;

accscalar_t threadMax = -at::numeric_limits<accscalar_t>::max();
accscalar_t threadExp = static_cast<accscalar_t>(0);

// The first smem segment is used to cache input values and the last
// segment is used for thread block reductions
extern __shared__ unsigned char smem[];
auto smem_reduction_cache = reinterpret_cast<accscalar_t*>(smem);

using LoadT = at::native::memory::aligned_vector<scalar_t, ILP>;
const LoadT* const input_vec_ptr = reinterpret_cast<const LoadT*>(input);

// Do the first step in max calculation:
MaxFloat<scalar_t, accscalar_t> maxFunc;
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
#pragma unroll
for (int i = 0; i < ILP; ++i) {
threadMax = maxFunc(threadMax, crnt_vec.val[i]);
}
}

accscalar_t max_k = blockReduceWarp<Max, accscalar_t>(smem_reduction_cache, threadMax,
Max<accscalar_t>(), -at::numeric_limits<accscalar_t>::max());

// Do the second step in sum exp calculation:
#ifdef PYTORCH_USE_EXPF
SumExpfFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
#else
SumExpFloat<scalar_t, accscalar_t> sumExpFunc(max_k);
#endif
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
#pragma unroll
for (int i = 0; i < ILP; ++i) {
threadExp = sumExpFunc(threadExp, crnt_vec.val[i]);
}
}

accscalar_t sumAll = blockReduceWarpInverse<Add, accscalar_t>(smem_reduction_cache, threadExp,
Add<accscalar_t>(), static_cast<accscalar_t>(0));

Epilogue<scalar_t, accscalar_t, outscalar_t> epilogue(max_k, sumAll);

using StoreT = at::native::memory::aligned_vector<outscalar_t, ILP>;
StoreT* output_vec_ptr = reinterpret_cast<StoreT*>(output);
for (index_t offset = threadIdx.x; offset * ILP < classes; offset += blockDim.x) {
LoadT crnt_vec = input_vec_ptr[offset];
StoreT out_vec;
#pragma unroll
for (int i = 0; i < ILP; ++i) {
out_vec.val[i] = epilogue(crnt_vec.val[i]);
}
output_vec_ptr[offset] = out_vec;
}
}
#endif

template <int ILP, typename scalar_t, typename accscalar_t, typename outscalar_t,
template <typename, typename, typename> class Epilogue, typename index_t = int32_t>
__global__ void
Expand Down Expand Up @@ -816,7 +930,11 @@ cunn_SoftMaxBackward(scalar_t *gradInput, const outscalar_t *output, const outsc
}
}

#ifdef USE_ROCM
template<template<typename, typename, typename> class Epilogue, template<typename, typename, typename> class EpilogueWithMul, bool is_log_softmax>
#else
template<template<typename, typename, typename> class Epilogue, bool is_log_softmax>
#endif
Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_to_float, const Tensor& output){
if (half_to_float) {
TORCH_CHECK(input_.scalar_type() == ScalarType::Half, "conversion is supported for Half type only");
Expand Down Expand Up @@ -858,6 +976,12 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
}
} else {
constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
#ifdef USE_ROCM
dim3 block(512);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, scalar_t, EpilogueWithMul>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
#else
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
Expand All @@ -876,6 +1000,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, scalar_t, Epilogue>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
}
#endif

C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Expand All @@ -894,6 +1019,12 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
}
} else {
constexpr int ILP = sizeof(float4) / sizeof(scalar_t);
#ifdef USE_ROCM
dim3 block(512);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
cunn_SoftMaxForwardGmem<ILP, scalar_t, accscalar_t, accscalar_t, EpilogueWithMul>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
#else
dim3 block = SoftMaxForward_getBlockSize(dim_size);
size_t smem_reduction_sz = block.x / C10_WARP_SIZE * sizeof(accscalar_t);
auto max_elements_per_smem = (at::cuda::getCurrentDeviceProperties()->sharedMemPerBlock -
Expand All @@ -912,6 +1043,7 @@ Tensor host_softmax(const Tensor & input_, const int64_t dim_, const bool half_t
cunn_SoftMaxForward<ILP, scalar_t, accscalar_t, accscalar_t, Epilogue>
<<<grid, block, smem_reduction_sz, stream>>>(output_ptr, input_ptr, dim_size);
}
#endif

C10_CUDA_KERNEL_LAUNCH_CHECK();
}
Expand Down Expand Up @@ -1069,7 +1201,11 @@ TORCH_IMPL_FUNC(log_softmax_cuda_out) (
const int64_t dim,
const bool half_to_float,
const Tensor &output) {
#ifdef USE_ROCM
host_softmax<LogSoftMaxForwardEpilogue, LogSoftMaxForwardEpilogue, true>(input, dim, half_to_float, output);
#else
host_softmax<LogSoftMaxForwardEpilogue,true>(input, dim, half_to_float, output);
#endif
}

TORCH_IMPL_FUNC(log_softmax_backward_cuda_out) (
Expand All @@ -1093,7 +1229,11 @@ TORCH_IMPL_FUNC(softmax_cuda_out) (
const int64_t dim,
const bool half_to_float,
const Tensor &output) {
#ifdef USE_ROCM
host_softmax<SoftMaxForwardEpilogue, SoftMaxForwardWithMulEpilogue, false>(input, dim, half_to_float, output);
#else
host_softmax<SoftMaxForwardEpilogue,false>(input, dim, half_to_float, output);
#endif
}

TORCH_IMPL_FUNC(softmax_backward_cuda_out)
Expand Down

0 comments on commit 5515a5a

Please sign in to comment.