[experimental][FP16] Add native __half support for sum_functor #1655

Open: wants to merge 1 commit into base: release/2.4
aten/src/ATen/native/cuda/ReduceSumProdKernel.cu (2 changes: 1 addition & 1 deletion)
@@ -172,7 +172,7 @@ template <
     typename GeneralDispatcher>
 static void reduce_dispatch(TensorIterator& iter, GeneralDispatcher op) {
   if (iter.dtype() == kHalf) {
-    return OpFunctor<at::Half, float>{}(iter);
+    return OpFunctor<at::Half, at::Half>{}(iter);
   } else if (iter.dtype(1) == kHalf && iter.dtype() == kFloat) {
     // type promotion that does cast and reduction in a single kernel
     return OpFunctor<at::Half, float, float>{}(iter);
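
Editorial aside, not part of this PR: the one-line change above switches the accumulator passed to the sum functor for kHalf inputs from float to at::Half. The standalone CUDA sketch below illustrates the numerical tradeoff of that switch; the file and kernel names are hypothetical, and it needs nvcc targeting sm_53 or newer (e.g. -arch=sm_70) for the __hadd path.

// half_sum_demo.cu (hypothetical, editorial illustration only)
#include <cuda_fp16.h>
#include <cstdio>

// Old behavior: accumulate FP16 inputs in a float accumulator.
__global__ void sum_accum_float(const __half* x, int n, float* out) {
  float acc = 0.0f;
  for (int i = 0; i < n; ++i) acc += __half2float(x[i]);
  *out = acc;
}

// New behavior under this PR: accumulate natively in __half (__hadd, sm_53+).
__global__ void sum_accum_half(const __half* x, int n, float* out) {
  __half acc = __float2half(0.0f);
  for (int i = 0; i < n; ++i) acc = __hadd(acc, x[i]);
  *out = __half2float(acc);
}

int main() {
  const int n = 4096;
  __half* x = nullptr;
  float* out = nullptr;
  cudaMallocManaged(&x, n * sizeof(__half));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) x[i] = __float2half(0.01f);

  sum_accum_float<<<1, 1>>>(x, n, out);
  cudaDeviceSynchronize();
  std::printf("float accumulator: %f\n", *out);  // roughly 40.9

  sum_accum_half<<<1, 1>>>(x, n, out);
  cudaDeviceSynchronize();
  std::printf("half accumulator:  %f\n", *out);  // noticeably lower: small increments
                                                 // are rounded away as the sum grows
  cudaFree(x);
  cudaFree(out);
  return 0;
}

Keeping the accumulator in FP16 avoids the half-to-float conversions inside the loop, but loses precision once the running sum grows, which is the usual reason reductions default to a wider accumulator.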
c10/util/Half-inl.h (5 changes: 5 additions & 0 deletions)
@@ -105,7 +105,12 @@ inline __device__ Half __ldg(const Half* ptr) {
 /// Arithmetic

 inline C10_HOST_DEVICE Half operator+(const Half& a, const Half& b) {
+#if (defined(__CUDACC__) || defined(__HIPCC__)) && \
+    defined(PYTORCH_ENABLE_NATIVE_HALF)
+  return __half{a} + __half{b};
+#else
   return static_cast<float>(a) + static_cast<float>(b);
+#endif
 }

 inline C10_HOST_DEVICE Half operator-(const Half& a, const Half& b) {
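
For context (editorial aside, not from the PR): when a build does not pass -D__CUDA_NO_HALF_OPERATORS__, cuda_fp16.h exposes device-side arithmetic operators for __half, so the new branch above compiles to a hardware FP16 add on sm_53+; otherwise Half arithmetic keeps round-tripping through float. A minimal sketch of those two paths, with a hypothetical kernel name:

#include <cuda_fp16.h>

__global__ void add_halves(const __half* a, const __half* b, __half* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= n) return;
#ifndef __CUDA_NO_HALF_OPERATORS__
  out[i] = a[i] + b[i];                        // native __half operator+ (hardware add)
#else
  out[i] = __float2half(__half2float(a[i]) +   // fallback: widen to float,
                        __half2float(b[i]));   // add, narrow back to half
#endif
}

Note that the guard in Half-inl.h additionally requires PYTORCH_ENABLE_NATIVE_HALF, so host code and default builds keep the existing float fallback.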
cmake/Dependencies.cmake (17 changes: 15 additions & 2 deletions)
@@ -1000,6 +1000,10 @@ if(USE_CUDNN)
   target_include_directories(torch::cudnn INTERFACE ${CUDNN_FRONTEND_INCLUDE_DIR})
 endif()

+# Note: This variable also affects CUDA as well.
+set(PYTORCH_ENABLE_HALF $ENV{PYTORCH_ENABLE_HALF} CACHE BOOL "Native support for half data type from EnVar." FORCE)
+
+
 # ---[ HIP
 if(USE_ROCM)
   # This prevents linking in the libtinfo from /opt/conda/lib which conflicts with ROCm libtinfo.
@@ -1042,7 +1046,11 @@ if(USE_ROCM)
   list(APPEND HIP_CXX_FLAGS -D__HIP_PLATFORM_AMD__=1)
   list(APPEND HIP_CXX_FLAGS -DCUDA_HAS_FP16=1)
   list(APPEND HIP_CXX_FLAGS -DUSE_ROCM)
-  list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
+  if(NOT PYTORCH_ENABLE_HALF)
+    list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_OPERATORS__=1)
+  else()
+    add_definitions(-DPYTORCH_ENABLE_NATIVE_HALF)
+  endif()
   list(APPEND HIP_CXX_FLAGS -D__HIP_NO_HALF_CONVERSIONS__=1)
   list(APPEND HIP_CXX_FLAGS -DTORCH_HIP_VERSION=${TORCH_HIP_VERSION})
   list(APPEND HIP_CXX_FLAGS -Wno-shift-count-negative)
@@ -1369,11 +1377,16 @@ if(NOT INTERN_BUILD_MOBILE)

   message(STATUS "Found CUDA with FP16 support, compiling with torch.cuda.HalfTensor")
   string(APPEND CMAKE_CUDA_FLAGS " -DCUDA_HAS_FP16=1"
-                                 " -D__CUDA_NO_HALF_OPERATORS__"
                                  " -D__CUDA_NO_HALF_CONVERSIONS__"
                                  " -D__CUDA_NO_HALF2_OPERATORS__"
                                  " -D__CUDA_NO_BFLOAT16_CONVERSIONS__")

+  if(NOT PYTORCH_ENABLE_HALF)
+    string(APPEND CMAKE_CUDA_FLAGS " -D__CUDA_NO_HALF_OPERATORS__")
+  else()
+    add_definitions(-DPYTORCH_ENABLE_NATIVE_HALF)
+  endif()
+
   string(APPEND CMAKE_C_FLAGS_RELEASE " -DNDEBUG")
   string(APPEND CMAKE_CXX_FLAGS_RELEASE " -DNDEBUG")
   if(NOT GENERATOR_IS_MULTI_CONFIG)
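
Taken together, these CMake changes read PYTORCH_ENABLE_HALF from the environment at configure time and use it to decide whether to keep suppressing the vendor half operators (-D__CUDA_NO_HALF_OPERATORS__ for CUDA, -D__HIP_NO_HALF_OPERATORS__=1 for ROCm) or to define PYTORCH_ENABLE_NATIVE_HALF instead. As an editorial aside, a hypothetical translation-unit probe like the one below (not part of the PR) could confirm which mode a given build selected:

#include <cstdio>

// Hypothetical helper: reports which half-operator mode the compile flags
// selected for this translation unit.
void report_half_mode() {
#if defined(PYTORCH_ENABLE_NATIVE_HALF)
  std::printf("native half operators enabled (PYTORCH_ENABLE_HALF=1)\n");
#elif defined(__CUDA_NO_HALF_OPERATORS__) || defined(__HIP_NO_HALF_OPERATORS__)
  std::printf("vendor half operators suppressed; c10::Half uses the float fallback\n");
#else
  std::printf("neither macro defined in this translation unit\n");
#endif
}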
setup.py (9 changes: 9 additions & 0 deletions)
@@ -197,6 +197,10 @@
 #
 # BUILD_PYTHON_ONLY
 #   Builds pytorch as a wheel using libtorch.so from a seperate wheel
+#
+# PYTORCH_ENABLE_HALF
+#   If set to '1' will enable native support for FP16 datatypes in certain functors.
+#   Note: Currently, this is considered experimental and will only affect reductions.

 import os
 import pkgutil
@@ -676,6 +680,11 @@ def run(self):
         else:
             report("-- Not using ITT")

+        if cmake_cache_vars["PYTORCH_ENABLE_HALF"]:
+            report("-- Using native FP16 support")
+        else:
+            report("-- Not using native FP16 support")
+
         # Do not use clang to compile extensions if `-fstack-clash-protection` is defined
         # in system CFLAGS
         c_flags = str(os.getenv("CFLAGS", ""))
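
For reference (not stated in the diff): since the CMake change reads $ENV{PYTORCH_ENABLE_HALF} at configure time, a from-source build would typically opt in by exporting the variable before building, for example PYTORCH_ENABLE_HALF=1 python setup.py develop; the exact invocation is illustrative, and the new report lines above simply echo whether the cached flag ended up set.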