[experimental][FP16] Add native __half support for sum_functor #1655

Open. Wants to merge 1 commit into base: release/2.4.
aten/src/ATen/native/cuda/ReduceSumProdKernel.cu (4 additions, 0 deletions)
@@ -172,7 +172,11 @@ template <
     typename GeneralDispatcher>
 static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {
   if (iter.dtype() == kHalf) {
+#ifdef PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF
+    return OpFunctor<at::Half, at::Half>{}(iter);
+#else
     return OpFunctor<at::Half, float>{}(iter);

[Inline review thread on the line above]
Reviewer: You need to ifdef this also.
Author: Good catch -- Thank you!

+#endif
   } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return OpFunctor<at::Half, float, float>{}(iter);
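Not part of the PR, just a minimal sketch of the trade-off the dispatch above chooses between: OpFunctor<at::Half, at::Half> keeps the running sum in fp16, while the default OpFunctor<at::Half, float> reads fp16 but accumulates in float. With only about 11 significand bits, an fp16 running sum stops absorbing small addends fairly quickly. The kernel names below are illustrative, the example assumes a GPU with native fp16 arithmetic (sm_53 or newer), and error checking is omitted.

// accumulator_demo.cu -- illustrative only, not PR code.
#include <cuda_fp16.h>
#include <cstdio>

__global__ void fill_ones(__half* data, int n) {
  for (int i = 0; i < n; ++i) data[i] = __float2half(1.0f);
}

// What the native path does: keep the running sum in fp16,
// rounding after every addition.
__global__ void sum_fp16_accumulator(const __half* data, int n, float* out) {
  __half acc = __float2half(0.0f);
  for (int i = 0; i < n; ++i) acc = __hadd(acc, data[i]);
  *out = __half2float(acc);
}

// What the default path does: read fp16, accumulate in float.
__global__ void sum_float_accumulator(const __half* data, int n, float* out) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) acc += __half2float(data[i]);
  *out = acc;
}

int main() {
  const int n = 4096;  // exact sum of n ones is 4096
  __half* d_data = nullptr;
  float* d_out = nullptr;
  cudaMalloc(&d_data, n * sizeof(__half));
  cudaMalloc(&d_out, 2 * sizeof(float));

  fill_ones<<<1, 1>>>(d_data, n);
  sum_fp16_accumulator<<<1, 1>>>(d_data, n, d_out);       // plateaus at 2048
  sum_float_accumulator<<<1, 1>>>(d_data, n, d_out + 1);  // exact: 4096

  float results[2] = {0.0f, 0.0f};
  cudaMemcpy(results, d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("fp16 accumulator: %.0f, float accumulator: %.0f\n",
              results[0], results[1]);

  cudaFree(d_data);
  cudaFree(d_out);
  return 0;
}

Once the fp16 accumulator reaches 2048, the spacing between representable fp16 values is 2, so acc + 1.0 rounds back to 2048 (ties to even) and the sum plateaus there, while the float-accumulator path returns the exact 4096. That accuracy cost is presumably why the native path is opt-in and labelled experimental.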
c10/util/Half-inl.h (5 additions, 0 deletions)
@@ -105,7 +105,12 @@ inline __device__ Half __ldg(const Half* ptr) {
 /// Arithmetic

 inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
+#if (defined(__CUDACC__) || defined(__HIPCC__)) && \
+    defined(PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+  return __half{a} + __half{b};
+#else
   return static_cast<float>(a) + static_cast<float>(b);
+#endif
 }

 inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
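As I read it, this operator change is coupled to the build-flag changes below: __half{a} + __half{b} needs the overloaded __half operators from cuda_fp16.h, and those are compiled out whenever __CUDA_NO_HALF_OPERATORS__ (or __HIP_NO_HALF_OPERATORS__ on ROCm) is defined, which is exactly the define the Dependencies.cmake hunks drop when the flag is enabled. A tiny illustration (function names are mine, and it assumes a device build with native fp16, sm_53 or newer):

// half_ops_demo.cu -- illustrative only, not PR code.
#include <cuda_fp16.h>

// Uses __half's overloaded operator+ from cuda_fp16.h. This fails to compile
// when __CUDA_NO_HALF_OPERATORS__ is defined, which is PyTorch's default.
__device__ __half add_with_operator(__half a, __half b) {
  return a + b;
}

// The intrinsic spelling is not hidden by that macro and still compiles.
__device__ __half add_with_intrinsic(__half a, __half b) {
  return __hadd(a, b);
}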
cmake/Dependencies.cmake (17 additions, 2 deletions)
@@ -1000,6 +1000,12 @@ if(USE_CUDNN)
   target_include_directories(torch::cudnn INTERFACE ${CUDNN_FRONTEND_INCLUDE_DIR})
 endif()

+# Note: This variable also affects CUDA.
+set(PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF
+    $ENV{PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF}
+    CACHE BOOL "Enable native support for half data type within ReduceSum." FORCE)
+
+
 # ---[ HIP
 if(USE_ROCM)
   # This prevents linking in the libtinfo from /opt/conda/lib which conflicts with ROCm libtinfo.
@@ -1042,7 +1048,11 @@ if(USE_ROCM)
   list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
   list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
   list(APPEND HIP_CXX_FLAGS -DUSE_ROCM)
-  list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
+  if(NOT PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+    list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
+  else()
+    add_definitions(-DPYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+  endif()
   list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
   list(APPEND HIP_CXX_FLAGS -DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
   list(APPEND HIP_CXX_FLAGS -Wno-shift-count-negative)
@@ -1369,11 +1379,16 @@ if(NOT INTERN_BUILD_MOBILE)

   message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor")
   string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1"
-                                 " -D__CUDA_NO_HALF_OPERATORS__"
                                  " -D__CUDA_NO_HALF_CONVERSIONS__"
                                  " -D__CUDA_NO_HALF2_OPERATORS__"
                                  " -D__CUDA_NO_BFLOAT16_CONVERSIONS__")

+  if(NOT PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+    string(APPEND CMAKE_CUDA_FLAGS " -D__CUDA_NO_HALF_OPERATORS__")
+  else()
+    add_definitions(-DPYTORCH_REDUCESUM_ENABLE_NATIVE_HALF)
+  endif()
+
   string(APPEND CMAKE_C_FLAGS_RELEASE " -DNDEBUG")
   string(APPEND CMAKE_CXX_FLAGS_RELEASE " -DNDEBUG")
   if(NOT GENERATOR_IS_MULTI_CONFIG)
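Summarizing the build plumbing in this file (my paraphrase, not text from the PR): the PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF environment variable is captured into a CMake cache BOOL; when it is enabled, the build stops defining __CUDA_NO_HALF_OPERATORS__ / __HIP_NO_HALF_OPERATORS__ and instead defines PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF globally via add_definitions(), which is the macro the #if guards in ReduceSumProdKernel.cu and Half-inl.h test. A throwaway translation unit like the one below (hypothetical, not part of the PR), compiled with the same flags the build uses, can confirm that the define actually reaches the compiler:

// check_flag.cu -- hypothetical sanity check, not part of the PR.
#include <cstdio>

int main() {
#ifdef PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF
  std::printf("native fp16 reduction path enabled\n");
#else
  std::printf("flag not set: reductions keep the float-accumulator fallback\n");
#endif
  return 0;
}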
setup.py (9 additions, 0 deletions)
@@ -135,6 +135,10 @@
 #   USE_ROCM_KERNEL_ASSERT=1
 #     Enable kernel assert in ROCm platform
 #
+#   PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF
+#     If set to '1', enables native support for FP16 data types in certain functors.
+#     Note: currently this is considered experimental and only affects reductions.
+#
 # Environment variables we respect (these environment variables are
 # conventional and are often understood/set by other software.)
 #
@@ -676,6 +680,11 @@ def run(self):
         else:
             report("-- Not using ITT")

+        if cmake_cache_vars["PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF"]:
+            report("-- Using native FP16 support")
+        else:
+            report("-- Not using native FP16 support")
+
         # Do not use clang to compile extensions if `-fstack-clash-protection` is defined
         # in system CFLAGS
         c_flags = str(os.getenv("CFLAGS", ""))
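For completeness, a usage note inferred from the documentation block above rather than stated in the diff: since setup.py forwards build-time environment variables to CMake, the feature would be enabled by exporting PYTORCH_REDUCESUM_ENABLE_NATIVE_HALF=1 before running the usual source build (for example, python setup.py develop), after which the new status line should report "-- Using native FP16 support".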